diff --git a/Cargo.lock b/Cargo.lock index cac23e44b..864a3b153 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2870,6 +2870,7 @@ dependencies = [ "ctor", "deadpool-sqlite", "glob", + "itertools 0.11.0", "matrix-sdk-base", "matrix-sdk-crypto", "matrix-sdk-store-encryption", diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index cf6a5effd..cf2055847 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -21,12 +21,13 @@ state-store = [] [dependencies] async-trait = { workspace = true } deadpool-sqlite = "0.5.0" +itertools = { workspace = true } matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base" } matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true } matrix-sdk-store-encryption = { version = "0.2.0", path = "../matrix-sdk-store-encryption" } rmp-serde = "1.1.1" ruma = { workspace = true } -rusqlite = "0.28.0" +rusqlite = { version = "0.28.0", features = ["limits"] } serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } diff --git a/crates/matrix-sdk-sqlite/src/error.rs b/crates/matrix-sdk-sqlite/src/error.rs index 0e4544067..84d89dc5b 100644 --- a/crates/matrix-sdk-sqlite/src/error.rs +++ b/crates/matrix-sdk-sqlite/src/error.rs @@ -69,22 +69,34 @@ pub enum OpenStoreError { pub enum Error { #[error(transparent)] Sqlite(rusqlite::Error), + + #[error("Failed to compute the maximum variable number from {0}")] + SqliteMaximumVariableNumber(i32), + #[error(transparent)] Pool(PoolError), + #[error(transparent)] Encode(rmp_serde::encode::Error), + #[error(transparent)] Decode(rmp_serde::decode::Error), + #[error(transparent)] Json(#[from] serde_json::Error), + #[error(transparent)] Encryption(matrix_sdk_store_encryption::Error), + #[error("can't save/load sessions or group sessions in the store before an account is stored")] AccountUnset, + #[error(transparent)] Pickle(#[from] vodozemac::PickleError), + #[error("An object failed to be decrypted while unpickling")] Unpickle, + #[error("Redaction failed: {0}")] Redaction(#[source] ruma::canonical_json::RedactionError), } diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index c00ef0025..8c74f0836 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -1,13 +1,17 @@ use std::{ borrow::Cow, + cmp::min, collections::{BTreeMap, BTreeSet}, - fmt, iter, + fmt, + future::Future, + iter, path::{Path, PathBuf}, sync::Arc, }; use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime}; +use itertools::Itertools; use matrix_sdk_base::{ deserialized_responses::RawAnySyncOrStrippedState, media::{MediaRequest, UniqueKey}, @@ -27,7 +31,7 @@ use ruma::{ serde::Raw, CanonicalJsonObject, EventId, OwnedEventId, OwnedUserId, RoomId, RoomVersionId, UserId, }; -use rusqlite::{OptionalExtension, Transaction}; +use rusqlite::{limits::Limit, OptionalExtension, Transaction}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::fs; use tracing::{debug, warn}; @@ -550,14 +554,19 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { } async fn get_kv_blobs(&self, keys: Vec) -> Result>> { - let sql_params = vec!["?"; keys.len()].join(", "); - let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})"); + let keys_length = keys.len(); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(keys))?.mapped(|row| row.get(0)).collect() + chunk_large_query_over(keys, Some(keys_length), |keys| { + let sql_params = repeat_vars(keys.len()); + let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})"); + + let params = rusqlite::params_from_iter(keys); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| row.get(0)).collect() }) - .await?) + }) + .await } async fn set_kv_blob(&self, key: Key, value: Vec) -> Result<()>; @@ -575,16 +584,17 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { }) .await?) } else { - let sql_params = vec!["?"; states.len()].join(", "); - let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})"); + chunk_large_query_over(states, None, |states| { + let sql_params = repeat_vars(states.len()); + let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})"); - Ok(self - .prepare(sql, move |mut stmt| { + self.prepare(sql, move |mut stmt| { stmt.query(rusqlite::params_from_iter(states))? .mapped(|row| row.get(0)) .collect() }) - .await?) + }) + .await } } @@ -594,20 +604,22 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { event_type: Key, state_keys: Vec, ) -> Result)>> { - let sql_params = vec!["?"; state_keys.len()].join(", "); - let sql = format!( - "SELECT stripped, data FROM state_event + chunk_large_query_over(state_keys, None, move |state_keys: Vec| { + let sql_params = repeat_vars(state_keys.len()); + let sql = format!( + "SELECT stripped, data FROM state_event WHERE room_id = ? AND event_type = ? AND state_key IN ({sql_params})" - ); - let params = [room_id, event_type].into_iter().chain(state_keys); + ); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))? - .mapped(|row| Ok((row.get(0)?, row.get(1)?))) - .collect() + let params = rusqlite::params_from_iter( + [room_id.clone(), event_type.clone()].into_iter().chain(state_keys), + ); + + self.prepare(sql, |mut stmt| { + stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect() }) - .await?) + }) + .await } async fn get_maybe_stripped_state_events( @@ -633,19 +645,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { room_id: Key, user_ids: Vec, ) -> Result, Vec)>> { - let sql_params = vec!["?"; user_ids.len()].join(", "); - let sql = format!( - "SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})" - ); - let params = iter::once(room_id).chain(user_ids); + let user_ids_length = user_ids.len(); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))? - .mapped(|row| Ok((row.get(0)?, row.get(1)?))) - .collect() + chunk_large_query_over(user_ids, Some(user_ids_length), move |user_ids| { + let sql_params = repeat_vars(user_ids.len()); + let sql = format!( + "SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})" + ); + + let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(user_ids)); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect() }) - .await?) + }) + .await } async fn get_user_ids(&self, room_id: Key, memberships: Vec) -> Result>> { @@ -655,14 +669,18 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { }) .await? } else { - let sql_params = vec!["?"; memberships.len()].join(", "); - let sql = format!( - "SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})" - ); - let params = iter::once(room_id).chain(memberships); + chunk_large_query_over(memberships, None, move |memberships| { + let sql_params = repeat_vars(memberships.len()); + let sql = format!( + "SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})" + ); - self.prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))?.mapped(|row| row.get(0)).collect() + let params = + rusqlite::params_from_iter(iter::once(room_id.clone()).chain(memberships)); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| row.get(0)).collect() + }) }) .await? }; @@ -701,19 +719,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { room_id: Key, names: Vec, ) -> Result, Vec)>> { - let sql_params = vec!["?"; names.len()].join(", "); - let sql = format!( - "SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})" - ); - let params = iter::once(room_id).chain(names); + let names_length = names.len(); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))? - .mapped(|row| Ok((row.get(0)?, row.get(1)?))) - .collect() + chunk_large_query_over(names, Some(names_length), move |names| { + let sql_params = repeat_vars(names.len()); + let sql = format!( + "SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})" + ); + + let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(names)); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect() }) - .await?) + }) + .await } async fn get_user_receipt( @@ -1531,6 +1551,51 @@ struct ReceiptData { user_id: OwnedUserId, } +/// Chunk a large query over some keys. +/// +/// Imagine there is a _dynamic_ query that runs potentially large number of +/// parameters, so much that the maximum number of parameters can be hit. Then, +/// this helper is for you. It will execute the query on chunks of parameters. +async fn chunk_large_query_over( + mut keys_to_chunk: Vec, + result_capacity: Option, + do_query: Query, +) -> Result> +where + Query: Fn(Vec) -> Fut, + Fut: Future, rusqlite::Error>>, +{ + let mut all_results = if let Some(capacity) = result_capacity { + Vec::with_capacity(capacity) + } else { + Vec::new() + }; + + // `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then divide by 2 to + // let space for more static parameters (not part of `keys_to_chunk`). + let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 / 2; + let maximum_chunk_size: usize = maximum_chunk_size + .try_into() + .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?; + + while !keys_to_chunk.is_empty() { + let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size)); + let chunk = keys_to_chunk; + keys_to_chunk = tail; + + all_results.extend(do_query(chunk).await?); + } + + Ok(all_results) +} + +/// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated. +fn repeat_vars(count: usize) -> impl fmt::Display { + assert_ne!(count, 0); + + iter::repeat("?").take(count).format(",") +} + #[cfg(test)] mod tests { use std::sync::atomic::{AtomicU32, Ordering::SeqCst};