diff --git a/Cargo.lock b/Cargo.lock index 43e5c1cc7..563feb55b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3203,6 +3203,7 @@ dependencies = [ "matrix-sdk-crypto", "matrix-sdk-store-encryption", "matrix-sdk-test", + "num_cpus", "once_cell", "rmp-serde", "ruma", diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index d1d62dd6d..ac32218dd 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -23,6 +23,7 @@ itertools = { workspace = true } matrix-sdk-base = { workspace = true, optional = true } matrix-sdk-crypto = { workspace = true, optional = true } matrix-sdk-store-encryption = { workspace = true } +num_cpus = "1.16.0" rmp-serde = { workspace = true } ruma = { workspace = true } rusqlite = { version = "0.32.1", features = ["limits"] } diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 805d4562f..0d6358fa2 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -51,9 +51,12 @@ use crate::{ repeat_vars, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt, SqliteKeyValueStoreConnExt, }, - OpenStoreError, + OpenStoreError, SqliteStoreConfig, }; +/// The database name. +const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3"; + /// A sqlite based cryptostore. #[derive(Clone)] pub struct SqliteCryptoStore { @@ -79,12 +82,21 @@ impl SqliteCryptoStore { path: impl AsRef, passphrase: Option<&str>, ) -> Result { - let path = path.as_ref(); - fs::create_dir_all(path).await.map_err(OpenStoreError::CreateDir)?; - let cfg = deadpool_sqlite::Config::new(path.join("matrix-sdk-crypto.sqlite3")); - let pool = cfg.create_pool(Runtime::Tokio1)?; + Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await + } - Self::open_with_pool(pool, passphrase).await + /// Open the sqlite-based crypto store with the config open config. + pub async fn open_with_config(config: SqliteStoreConfig) -> Result { + let SqliteStoreConfig { path, passphrase, pool_config } = config; + + fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?; + + let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME)); + config.pool = Some(pool_config); + + let pool = config.create_pool(Runtime::Tokio1)?; + + Self::open_with_pool(pool, passphrase.as_deref()).await } /// Create a sqlite-based crypto store using the given sqlite database pool. @@ -1421,6 +1433,7 @@ mod tests { use tokio::fs; use super::SqliteCryptoStore; + use crate::SqliteStoreConfig; static TMP_DIR: Lazy = Lazy::new(|| tempdir().unwrap()); @@ -1431,8 +1444,8 @@ mod tests { database: SqliteCryptoStore, } - async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb { - let db_name = "matrix-sdk-crypto.sqlite3"; + fn copy_db(data_path: &str) -> TempDir { + let db_name = super::DATABASE_NAME; let manifest_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../.."); let database_path = manifest_path.join(data_path).join(db_name); @@ -1443,6 +1456,12 @@ mod tests { // Copy the test database to the tempdir so our test runs are idempotent. std::fs::copy(&database_path, destination).unwrap(); + tmpdir + } + + async fn get_test_db(data_path: &str, passphrase: Option<&str>) -> TestDb { + let tmpdir = copy_db(data_path); + let database = SqliteCryptoStore::open(tmpdir.path(), passphrase) .await .expect("Can't open the test store"); @@ -1450,6 +1469,16 @@ mod tests { TestDb { _dir: tmpdir, database } } + #[async_test] + async fn test_pool_size() { + let store_open_config = + SqliteStoreConfig::new(TMP_DIR.path().join("test_pool_size")).pool_max_size(42); + + let store = SqliteCryptoStore::open_with_config(store_open_config).await.unwrap(); + + assert_eq!(store.pool.status().max_size, 42); + } + /// Test that we didn't regress in our storage layer by loading data from a /// pre-filled database, or in other words use a test vector for this. #[async_test] @@ -1787,6 +1816,7 @@ mod tests { assert_eq!(backup_keys.backup_version.unwrap(), "6"); assert!(backup_keys.decryption_key.is_some()); } + async fn get_store( name: &str, passphrase: Option<&str>, diff --git a/crates/matrix-sdk-sqlite/src/event_cache_store.rs b/crates/matrix-sdk-sqlite/src/event_cache_store.rs index 96858570b..9a14fa591 100644 --- a/crates/matrix-sdk-sqlite/src/event_cache_store.rs +++ b/crates/matrix-sdk-sqlite/src/event_cache_store.rs @@ -47,7 +47,7 @@ use crate::{ repeat_vars, time_to_timestamp, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt, SqliteKeyValueStoreConnExt, SqliteTransactionExt, }, - OpenStoreError, + OpenStoreError, SqliteStoreConfig, }; mod keys { @@ -60,6 +60,9 @@ mod keys { pub const MEDIA: &str = "media"; } +/// The database name. +const DATABASE_NAME: &str = "matrix-sdk-event-cache.sqlite3"; + /// Identifier of the latest database version. /// /// This is used to figure whether the SQLite database requires a migration. @@ -96,9 +99,21 @@ impl SqliteEventCacheStore { path: impl AsRef, passphrase: Option<&str>, ) -> Result { - let pool = create_pool(path.as_ref()).await?; + Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await + } - Self::open_with_pool(pool, passphrase).await + /// Open the sqlite-based event cache store with the config open config. + pub async fn open_with_config(config: SqliteStoreConfig) -> Result { + let SqliteStoreConfig { path, passphrase, pool_config } = config; + + fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?; + + let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME)); + config.pool = Some(pool_config); + + let pool = config.create_pool(Runtime::Tokio1)?; + + Self::open_with_pool(pool, passphrase.as_deref()).await } /// Open an SQLite-based event cache store using the given SQLite database @@ -294,12 +309,6 @@ impl TransactionExtForLinkedChunks for Transaction<'_> { } } -async fn create_pool(path: &Path) -> Result { - fs::create_dir_all(path).await.map_err(OpenStoreError::CreateDir)?; - let cfg = deadpool_sqlite::Config::new(path.join("matrix-sdk-event-cache.sqlite3")); - Ok(cfg.create_pool(Runtime::Tokio1)?) -} - /// Run migrations for the given version of the database. async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> { if version == 0 { @@ -1348,6 +1357,7 @@ fn insert_chunk( #[cfg(test)] mod tests { use std::{ + path::PathBuf, sync::atomic::{AtomicU32, Ordering::SeqCst}, time::Duration, }; @@ -1373,14 +1383,18 @@ mod tests { use tempfile::{tempdir, TempDir}; use super::SqliteEventCacheStore; - use crate::utils::SqliteAsyncConnExt; + use crate::{utils::SqliteAsyncConnExt, SqliteStoreConfig}; static TMP_DIR: Lazy = Lazy::new(|| tempdir().unwrap()); static NUM: AtomicU32 = AtomicU32::new(0); - async fn get_event_cache_store() -> Result { + fn new_event_cache_store_workspace() -> PathBuf { let name = NUM.fetch_add(1, SeqCst).to_string(); - let tmpdir_path = TMP_DIR.path().join(name); + TMP_DIR.path().join(name) + } + + async fn get_event_cache_store() -> Result { + let tmpdir_path = new_event_cache_store_workspace(); tracing::info!("using event cache store @ {}", tmpdir_path.to_str().unwrap()); @@ -1403,6 +1417,16 @@ mod tests { .expect("querying media cache content by last access failed") } + #[async_test] + async fn test_pool_size() { + let tmpdir_path = new_event_cache_store_workspace(); + let store_open_config = SqliteStoreConfig::new(tmpdir_path).pool_max_size(42); + + let store = SqliteEventCacheStore::open_with_config(store_open_config).await.unwrap(); + + assert_eq!(store.pool.status().max_size, 42); + } + #[async_test] async fn test_last_access() { let event_cache_store = get_event_cache_store().await.expect("creating media cache failed"); diff --git a/crates/matrix-sdk-sqlite/src/lib.rs b/crates/matrix-sdk-sqlite/src/lib.rs index ff62c3f7f..75e1bed68 100644 --- a/crates/matrix-sdk-sqlite/src/lib.rs +++ b/crates/matrix-sdk-sqlite/src/lib.rs @@ -24,6 +24,9 @@ mod event_cache_store; #[cfg(feature = "state-store")] mod state_store; mod utils; +use std::path::{Path, PathBuf}; + +use deadpool_sqlite::PoolConfig; #[cfg(feature = "crypto-store")] pub use self::crypto_store::SqliteCryptoStore; @@ -35,3 +38,59 @@ pub use self::state_store::SqliteStateStore; #[cfg(test)] matrix_sdk_test::init_tracing_for_tests!(); + +/// A configuration structure used for opening a store. +pub struct SqliteStoreConfig { + /// Path to the database, without the file name. + path: PathBuf, + /// Passphrase to open the store, if any. + passphrase: Option, + /// The pool configuration for [`deadpool_sqlite`]. + pool_config: PoolConfig, +} + +impl SqliteStoreConfig { + /// Create a new [`SqliteStoreConfig`] with a path representing the + /// directory containing the store database. + pub fn new

(path: P) -> Self + where + P: AsRef, + { + Self { + path: path.as_ref().to_path_buf(), + passphrase: None, + pool_config: PoolConfig::new(num_cpus::get_physical() * 4), + } + } + + /// Define the passphrase if the store is encoded. + pub fn passphrase(mut self, passphrase: Option<&str>) -> Self { + self.passphrase = passphrase.map(|passphrase| passphrase.to_owned()); + self + } + + /// Define the maximum pool size for [`deadpool_sqlite`]. + /// + /// See [`deadpool_sqlite::PoolConfig::max_size`] to learn more. + pub fn pool_max_size(mut self, max_size: usize) -> Self { + self.pool_config.max_size = max_size; + self + } +} + +#[cfg(test)] +mod tests { + use std::path::{Path, PathBuf}; + + use super::SqliteStoreConfig; + + #[test] + fn test_store_open_config() { + let store_open_config = + SqliteStoreConfig::new(Path::new("foo")).passphrase(Some("bar")).pool_max_size(42); + + assert_eq!(store_open_config.path, PathBuf::from("foo")); + assert_eq!(store_open_config.passphrase, Some("bar".to_owned())); + assert_eq!(store_open_config.pool_config.max_size, 42); + } +} diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 1f863f45b..9cf0821f8 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -46,7 +46,7 @@ use crate::{ repeat_vars, Key, SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt, SqliteKeyValueStoreConnExt, }, - OpenStoreError, + OpenStoreError, SqliteStoreConfig, }; mod keys { @@ -64,6 +64,8 @@ mod keys { pub const DEPENDENTS_SEND_QUEUE: &str = "dependent_send_queue_events"; } +const DATABASE_NAME: &str = "matrix-sdk-state.sqlite3"; + /// Identifier of the latest database version. /// /// This is used to figure whether the sqlite database requires a migration. @@ -92,9 +94,21 @@ impl SqliteStateStore { path: impl AsRef, passphrase: Option<&str>, ) -> Result { - let pool = create_pool(path.as_ref()).await?; + Self::open_with_config(SqliteStoreConfig::new(path).passphrase(passphrase)).await + } - Self::open_with_pool(pool, passphrase).await + /// Open the sqlite-based state store with the config open config. + pub async fn open_with_config(config: SqliteStoreConfig) -> Result { + let SqliteStoreConfig { path, passphrase, pool_config } = config; + + fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir)?; + + let mut config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME)); + config.pool = Some(pool_config); + + let pool = config.create_pool(Runtime::Tokio1)?; + + Self::open_with_pool(pool, passphrase.as_deref()).await } /// Create a sqlite-based state store using the given sqlite database pool. @@ -448,12 +462,6 @@ impl SqliteStateStore { } } -async fn create_pool(path: &Path) -> Result { - fs::create_dir_all(path).await.map_err(OpenStoreError::CreateDir)?; - let cfg = deadpool_sqlite::Config::new(path.join("matrix-sdk-state.sqlite3")); - Ok(cfg.create_pool(Runtime::Tokio1)?) -} - /// Initialize the database. async fn init(conn: &SqliteAsyncConn) -> Result<()> { // First turn on WAL mode, this can't be done in the transaction, it fails with @@ -2126,20 +2134,29 @@ mod tests { #[cfg(test)] mod encrypted_tests { - use std::sync::atomic::{AtomicU32, Ordering::SeqCst}; + use std::{ + path::PathBuf, + sync::atomic::{AtomicU32, Ordering::SeqCst}, + }; use matrix_sdk_base::{statestore_integration_tests, StateStore, StoreError}; + use matrix_sdk_test::async_test; use once_cell::sync::Lazy; use tempfile::{tempdir, TempDir}; use super::SqliteStateStore; + use crate::SqliteStoreConfig; static TMP_DIR: Lazy = Lazy::new(|| tempdir().unwrap()); static NUM: AtomicU32 = AtomicU32::new(0); - async fn get_store() -> Result { + fn new_state_store_workspace() -> PathBuf { let name = NUM.fetch_add(1, SeqCst).to_string(); - let tmpdir_path = TMP_DIR.path().join(name); + TMP_DIR.path().join(name) + } + + async fn get_store() -> Result { + let tmpdir_path = new_state_store_workspace(); tracing::info!("using store @ {}", tmpdir_path.to_str().unwrap()); @@ -2148,6 +2165,16 @@ mod encrypted_tests { .unwrap()) } + #[async_test] + async fn test_pool_size() { + let tmpdir_path = new_state_store_workspace(); + let store_open_config = SqliteStoreConfig::new(tmpdir_path).pool_max_size(42); + + let store = SqliteStateStore::open_with_config(store_open_config).await.unwrap(); + + assert_eq!(store.pool.status().max_size, 42); + } + statestore_integration_tests!(); } @@ -2161,6 +2188,7 @@ mod migration_tests { }, }; + use deadpool_sqlite::Runtime; use matrix_sdk_base::{ store::{ChildTransactionId, DependentQueuedRequestKind, SerializableEventContent}, sync::UnreadNotificationsCount, @@ -2179,11 +2207,13 @@ mod migration_tests { use rusqlite::Transaction; use serde_json::json; use tempfile::{tempdir, TempDir}; + use tokio::fs; - use super::{create_pool, init, keys, SqliteStateStore}; + use super::{init, keys, SqliteStateStore, DATABASE_NAME}; use crate::{ error::{Error, Result}, utils::{SqliteAsyncConnExt, SqliteKeyValueStoreAsyncConnExt}, + OpenStoreError, }; static TMP_DIR: Lazy = Lazy::new(|| tempdir().unwrap()); @@ -2196,7 +2226,12 @@ mod migration_tests { } async fn create_fake_db(path: &Path, version: u8) -> Result { - let pool = create_pool(path).await.unwrap(); + fs::create_dir_all(&path).await.map_err(OpenStoreError::CreateDir).unwrap(); + + let config = deadpool_sqlite::Config::new(path.join(DATABASE_NAME)); + // use default pool config + + let pool = config.create_pool(Runtime::Tokio1).unwrap(); let conn = pool.get().await?; init(&conn).await?;