From 380bce0604843758e6ab6b7385e200e24df4be4c Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Thu, 6 Jul 2023 11:56:49 +0200 Subject: [PATCH 1/5] fix(sqlite): Prevent reaching the maximum number of parameters in an SQL statement. This patch introduces a new `chunk_large_query_over` function. Imagine there is a _dynamic_ query that runs with a potentially large number of parameters, so many that the maximum number of parameters can be hit. Then, this function is for you. It will execute the query on chunks of parameters. `chunk_large_query_over` uses `Vec::split_off` to avoid cloning `Key`s as much as possible. It's difficult to use references here because of the async nature of `SqliteObjectExt::prepare`. This patch updates `get_kv_blobs`, `get_room_infos`, `get_maybe_stripped_state_events_for_keys`, `get_profiles`, `get_user_ids`, and `get_display_names` to use this new function. This patch finally adds the `repeat_vars` function to replace the `vec!["?"; n].join(", ")` pattern in a more efficient way. There are fewer memory allocations.
--- crates/matrix-sdk-sqlite/Cargo.toml | 2 +- crates/matrix-sdk-sqlite/src/error.rs | 12 ++ crates/matrix-sdk-sqlite/src/state_store.rs | 175 ++++++++++++++------ 3 files changed, 135 insertions(+), 54 deletions(-) diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index cf6a5effd..bf6135f9e 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -26,7 +26,7 @@ matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional matrix-sdk-store-encryption = { version = "0.2.0", path = "../matrix-sdk-store-encryption" } rmp-serde = "1.1.1" ruma = { workspace = true } -rusqlite = "0.28.0" +rusqlite = { version = "0.28.0", features = ["limits"] } serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } diff --git a/crates/matrix-sdk-sqlite/src/error.rs b/crates/matrix-sdk-sqlite/src/error.rs index 0e4544067..84d89dc5b 100644 --- a/crates/matrix-sdk-sqlite/src/error.rs +++ b/crates/matrix-sdk-sqlite/src/error.rs @@ -69,22 +69,34 @@ pub enum OpenStoreError { pub enum Error { #[error(transparent)] Sqlite(rusqlite::Error), + + #[error("Failed to compute the maximum variable number from {0}")] + SqliteMaximumVariableNumber(i32), + #[error(transparent)] Pool(PoolError), + #[error(transparent)] Encode(rmp_serde::encode::Error), + #[error(transparent)] Decode(rmp_serde::decode::Error), + #[error(transparent)] Json(#[from] serde_json::Error), + #[error(transparent)] Encryption(matrix_sdk_store_encryption::Error), + #[error("can't save/load sessions or group sessions in the store before an account is stored")] AccountUnset, + #[error(transparent)] Pickle(#[from] vodozemac::PickleError), + #[error("An object failed to be decrypted while unpickling")] Unpickle, + #[error("Redaction failed: {0}")] Redaction(#[source] ruma::canonical_json::RedactionError), } diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs 
b/crates/matrix-sdk-sqlite/src/state_store.rs index c00ef0025..9f49c4c23 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -1,7 +1,11 @@ use std::{ borrow::Cow, + cmp::min, collections::{BTreeMap, BTreeSet}, - fmt, iter, + fmt, + future::Future, + iter, + ops::Not, path::{Path, PathBuf}, sync::Arc, }; @@ -27,7 +31,7 @@ use ruma::{ serde::Raw, CanonicalJsonObject, EventId, OwnedEventId, OwnedUserId, RoomId, RoomVersionId, UserId, }; -use rusqlite::{OptionalExtension, Transaction}; +use rusqlite::{limits::Limit, OptionalExtension, Transaction}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::fs; use tracing::{debug, warn}; @@ -550,14 +554,19 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { } async fn get_kv_blobs(&self, keys: Vec) -> Result>> { - let sql_params = vec!["?"; keys.len()].join(", "); - let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})"); + let keys_length = keys.len(); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(keys))?.mapped(|row| row.get(0)).collect() + chunk_large_query_over(keys, Some(keys_length), |keys| { + let sql_params = repeat_vars(keys.len()); + let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})"); + + let params = rusqlite::params_from_iter(keys); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| row.get(0)).collect() }) - .await?) + }) + .await } async fn set_kv_blob(&self, key: Key, value: Vec) -> Result<()>; @@ -575,16 +584,17 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { }) .await?) 
} else { - let sql_params = vec!["?"; states.len()].join(", "); - let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})"); + chunk_large_query_over(states, None, |states| { + let sql_params = repeat_vars(states.len()); + let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})"); - Ok(self - .prepare(sql, move |mut stmt| { + self.prepare(sql, move |mut stmt| { stmt.query(rusqlite::params_from_iter(states))? .mapped(|row| row.get(0)) .collect() }) - .await?) + }) + .await } } @@ -594,20 +604,22 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { event_type: Key, state_keys: Vec, ) -> Result)>> { - let sql_params = vec!["?"; state_keys.len()].join(", "); - let sql = format!( - "SELECT stripped, data FROM state_event + chunk_large_query_over(state_keys, None, move |state_keys: Vec| { + let sql_params = repeat_vars(state_keys.len()); + let sql = format!( + "SELECT stripped, data FROM state_event WHERE room_id = ? AND event_type = ? AND state_key IN ({sql_params})" - ); - let params = [room_id, event_type].into_iter().chain(state_keys); + ); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))? - .mapped(|row| Ok((row.get(0)?, row.get(1)?))) - .collect() + let params = rusqlite::params_from_iter( + [room_id.clone(), event_type.clone()].into_iter().chain(state_keys), + ); + + self.prepare(sql, |mut stmt| { + stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect() }) - .await?) + }) + .await } async fn get_maybe_stripped_state_events( @@ -633,19 +645,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { room_id: Key, user_ids: Vec, ) -> Result, Vec)>> { - let sql_params = vec!["?"; user_ids.len()].join(", "); - let sql = format!( - "SELECT user_id, data FROM profile WHERE room_id = ? 
AND user_id IN ({sql_params})" - ); - let params = iter::once(room_id).chain(user_ids); + let user_ids_length = user_ids.len(); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))? - .mapped(|row| Ok((row.get(0)?, row.get(1)?))) - .collect() + chunk_large_query_over(user_ids, Some(user_ids_length), move |user_ids| { + let sql_params = repeat_vars(user_ids.len()); + let sql = format!( + "SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})" + ); + + let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(user_ids)); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect() }) - .await?) + }) + .await } async fn get_user_ids(&self, room_id: Key, memberships: Vec) -> Result>> { @@ -655,14 +669,18 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { }) .await? } else { - let sql_params = vec!["?"; memberships.len()].join(", "); - let sql = format!( - "SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})" - ); - let params = iter::once(room_id).chain(memberships); + chunk_large_query_over(memberships, None, move |memberships| { + let sql_params = repeat_vars(memberships.len()); + let sql = format!( + "SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})" + ); - self.prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))?.mapped(|row| row.get(0)).collect() + let params = + rusqlite::params_from_iter(iter::once(room_id.clone()).chain(memberships)); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| row.get(0)).collect() + }) }) .await? }; @@ -701,19 +719,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt { room_id: Key, names: Vec, ) -> Result, Vec)>> { - let sql_params = vec!["?"; names.len()].join(", "); - let sql = format!( - "SELECT name, data FROM display_name WHERE room_id = ? 
AND name IN ({sql_params})" - ); - let params = iter::once(room_id).chain(names); + let names_length = names.len(); - Ok(self - .prepare(sql, move |mut stmt| { - stmt.query(rusqlite::params_from_iter(params))? - .mapped(|row| Ok((row.get(0)?, row.get(1)?))) - .collect() + chunk_large_query_over(names, Some(names_length), move |names| { + let sql_params = repeat_vars(names.len()); + let sql = format!( + "SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})" + ); + + let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(names)); + + self.prepare(sql, move |mut stmt| { + stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect() }) - .await?) + }) + .await } async fn get_user_receipt( @@ -1531,6 +1551,55 @@ struct ReceiptData { user_id: OwnedUserId, } +/// Chunk a large query over some keys. +/// +/// Imagine there is a _dynamic_ query that runs potentially large number of +/// parameters, so much that the maximum number of parameters can be hit. Then, +/// this helper is for you. It will execute the query on chunks of parameters. +async fn chunk_large_query_over( + mut keys_to_chunk: Vec, + result_capacity: Option, + do_query: Query, +) -> Result> +where + Query: Fn(Vec) -> Fut, + Fut: Future, rusqlite::Error>>, +{ + let mut all_results = if let Some(capacity) = result_capacity { + Vec::with_capacity(capacity) + } else { + Vec::new() + }; + + // `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then substract 1 in + // case of. 
+ let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 - 1; + let maximum_chunk_size: usize = maximum_chunk_size + .try_into() + .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?; + + while keys_to_chunk.is_empty().not() { + let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size)); + let chunk = keys_to_chunk; + keys_to_chunk = tail; + + all_results.extend(do_query(chunk).await?); + } + + Ok(all_results) +} + +/// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated. +fn repeat_vars(count: usize) -> String { + assert_ne!(count, 0); + + let mut str = "?,".repeat(count); + // Remove trailing comma. + str.pop(); + + str +} + #[cfg(test)] mod tests { use std::sync::atomic::{AtomicU32, Ordering::SeqCst}; From d33e2aff2e05abb7c67181b36f393b6386cdfe31 Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Thu, 6 Jul 2023 15:18:41 +0200 Subject: [PATCH 2/5] doc(sqlite): Fix a typo. --- crates/matrix-sdk-sqlite/src/state_store.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 9f49c4c23..25a688248 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -1571,7 +1571,7 @@ where Vec::new() }; - // `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then substract 1 in + // `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then subtract 1 in // case of. let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 - 1; let maximum_chunk_size: usize = maximum_chunk_size From 23dae0612d21c836ca7ffa2ad896671778a8753e Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Thu, 6 Jul 2023 16:22:37 +0200 Subject: [PATCH 3/5] chore(sqlite): Use `itertools` in `repeat_vars`. 
--- Cargo.lock | 1 + crates/matrix-sdk-sqlite/Cargo.toml | 1 + crates/matrix-sdk-sqlite/src/state_store.rs | 9 +++------ 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cac23e44b..864a3b153 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2870,6 +2870,7 @@ dependencies = [ "ctor", "deadpool-sqlite", "glob", + "itertools 0.11.0", "matrix-sdk-base", "matrix-sdk-crypto", "matrix-sdk-store-encryption", diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index bf6135f9e..cf2055847 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -21,6 +21,7 @@ state-store = [] [dependencies] async-trait = { workspace = true } deadpool-sqlite = "0.5.0" +itertools = { workspace = true } matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base" } matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true } matrix-sdk-store-encryption = { version = "0.2.0", path = "../matrix-sdk-store-encryption" } diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 25a688248..3a0aff4c1 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -12,6 +12,7 @@ use std::{ use async_trait::async_trait; use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime}; +use itertools::Itertools; use matrix_sdk_base::{ deserialized_responses::RawAnySyncOrStrippedState, media::{MediaRequest, UniqueKey}, @@ -1590,14 +1591,10 @@ where } /// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated. -fn repeat_vars(count: usize) -> String { +fn repeat_vars(count: usize) -> impl fmt::Display { assert_ne!(count, 0); - let mut str = "?,".repeat(count); - // Remove trailing comma. 
- str.pop(); - - str + iter::repeat("?").take(count).format(",") } #[cfg(test)] From 4e20dc5047e49e99c4730d60c261607e0f55c6f2 Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Thu, 6 Jul 2023 16:27:32 +0200 Subject: [PATCH 4/5] chore(sqlite): To make a friend smile. --- crates/matrix-sdk-sqlite/src/state_store.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 3a0aff4c1..85779f0cc 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -5,7 +5,6 @@ use std::{ fmt, future::Future, iter, - ops::Not, path::{Path, PathBuf}, sync::Arc, }; @@ -1579,7 +1578,7 @@ where .try_into() .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?; - while keys_to_chunk.is_empty().not() { + while !keys_to_chunk.is_empty() { let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size)); let chunk = keys_to_chunk; keys_to_chunk = tail; From a70e79420c1b3822b6a4df65ede482b267c101bb Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Thu, 6 Jul 2023 16:46:35 +0200 Subject: [PATCH 5/5] fix(sqlite): Ensure there is free space for static parameters, in `chunk_large_query_over`. --- crates/matrix-sdk-sqlite/src/state_store.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 85779f0cc..8c74f0836 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -1571,9 +1571,9 @@ where Vec::new() }; - // `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then subtract 1 in - // case of. - let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 - 1; + // `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then divide by 2 to + // let space for more static parameters (not part of `keys_to_chunk`). 
+ let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 / 2; let maximum_chunk_size: usize = maximum_chunk_size .try_into() .map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;