fix(sqlite): Prevent reaching the maximum number of parameters in an SQL statement

fix(sqlite): Prevent reaching the maximum number of parameters in an SQL statement
This commit is contained in:
Ivan Enderlin
2023-07-06 17:07:23 +02:00
committed by GitHub
4 changed files with 133 additions and 54 deletions

1
Cargo.lock generated
View File

@@ -2870,6 +2870,7 @@ dependencies = [
"ctor",
"deadpool-sqlite",
"glob",
"itertools 0.11.0",
"matrix-sdk-base",
"matrix-sdk-crypto",
"matrix-sdk-store-encryption",

View File

@@ -21,12 +21,13 @@ state-store = []
[dependencies]
async-trait = { workspace = true }
deadpool-sqlite = "0.5.0"
itertools = { workspace = true }
matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base" }
matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true }
matrix-sdk-store-encryption = { version = "0.2.0", path = "../matrix-sdk-store-encryption" }
rmp-serde = "1.1.1"
ruma = { workspace = true }
rusqlite = "0.28.0"
rusqlite = { version = "0.28.0", features = ["limits"] }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }

View File

@@ -69,22 +69,34 @@ pub enum OpenStoreError {
pub enum Error {
#[error(transparent)]
Sqlite(rusqlite::Error),
#[error("Failed to compute the maximum variable number from {0}")]
SqliteMaximumVariableNumber(i32),
#[error(transparent)]
Pool(PoolError),
#[error(transparent)]
Encode(rmp_serde::encode::Error),
#[error(transparent)]
Decode(rmp_serde::decode::Error),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Encryption(matrix_sdk_store_encryption::Error),
#[error("can't save/load sessions or group sessions in the store before an account is stored")]
AccountUnset,
#[error(transparent)]
Pickle(#[from] vodozemac::PickleError),
#[error("An object failed to be decrypted while unpickling")]
Unpickle,
#[error("Redaction failed: {0}")]
Redaction(#[source] ruma::canonical_json::RedactionError),
}

View File

@@ -1,13 +1,17 @@
use std::{
borrow::Cow,
cmp::min,
collections::{BTreeMap, BTreeSet},
fmt, iter,
fmt,
future::Future,
iter,
path::{Path, PathBuf},
sync::Arc,
};
use async_trait::async_trait;
use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime};
use itertools::Itertools;
use matrix_sdk_base::{
deserialized_responses::RawAnySyncOrStrippedState,
media::{MediaRequest, UniqueKey},
@@ -27,7 +31,7 @@ use ruma::{
serde::Raw,
CanonicalJsonObject, EventId, OwnedEventId, OwnedUserId, RoomId, RoomVersionId, UserId,
};
use rusqlite::{OptionalExtension, Transaction};
use rusqlite::{limits::Limit, OptionalExtension, Transaction};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::fs;
use tracing::{debug, warn};
@@ -550,14 +554,19 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
}
async fn get_kv_blobs(&self, keys: Vec<Key>) -> Result<Vec<Vec<u8>>> {
let sql_params = vec!["?"; keys.len()].join(", ");
let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})");
let keys_length = keys.len();
Ok(self
.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(keys))?.mapped(|row| row.get(0)).collect()
chunk_large_query_over(keys, Some(keys_length), |keys| {
let sql_params = repeat_vars(keys.len());
let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})");
let params = rusqlite::params_from_iter(keys);
self.prepare(sql, move |mut stmt| {
stmt.query(params)?.mapped(|row| row.get(0)).collect()
})
.await?)
})
.await
}
async fn set_kv_blob(&self, key: Key, value: Vec<u8>) -> Result<()>;
@@ -575,16 +584,17 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
})
.await?)
} else {
let sql_params = vec!["?"; states.len()].join(", ");
let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})");
chunk_large_query_over(states, None, |states| {
let sql_params = repeat_vars(states.len());
let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})");
Ok(self
.prepare(sql, move |mut stmt| {
self.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(states))?
.mapped(|row| row.get(0))
.collect()
})
.await?)
})
.await
}
}
@@ -594,20 +604,22 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
event_type: Key,
state_keys: Vec<Key>,
) -> Result<Vec<(bool, Vec<u8>)>> {
let sql_params = vec!["?"; state_keys.len()].join(", ");
let sql = format!(
"SELECT stripped, data FROM state_event
chunk_large_query_over(state_keys, None, move |state_keys: Vec<Key>| {
let sql_params = repeat_vars(state_keys.len());
let sql = format!(
"SELECT stripped, data FROM state_event
WHERE room_id = ? AND event_type = ? AND state_key IN ({sql_params})"
);
let params = [room_id, event_type].into_iter().chain(state_keys);
);
Ok(self
.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(params))?
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
.collect()
let params = rusqlite::params_from_iter(
[room_id.clone(), event_type.clone()].into_iter().chain(state_keys),
);
self.prepare(sql, |mut stmt| {
stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
})
.await?)
})
.await
}
async fn get_maybe_stripped_state_events(
@@ -633,19 +645,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
room_id: Key,
user_ids: Vec<Key>,
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
let sql_params = vec!["?"; user_ids.len()].join(", ");
let sql = format!(
"SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})"
);
let params = iter::once(room_id).chain(user_ids);
let user_ids_length = user_ids.len();
Ok(self
.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(params))?
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
.collect()
chunk_large_query_over(user_ids, Some(user_ids_length), move |user_ids| {
let sql_params = repeat_vars(user_ids.len());
let sql = format!(
"SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})"
);
let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(user_ids));
self.prepare(sql, move |mut stmt| {
stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
})
.await?)
})
.await
}
async fn get_user_ids(&self, room_id: Key, memberships: Vec<Key>) -> Result<Vec<Vec<u8>>> {
@@ -655,14 +669,18 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
})
.await?
} else {
let sql_params = vec!["?"; memberships.len()].join(", ");
let sql = format!(
"SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})"
);
let params = iter::once(room_id).chain(memberships);
chunk_large_query_over(memberships, None, move |memberships| {
let sql_params = repeat_vars(memberships.len());
let sql = format!(
"SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})"
);
self.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(params))?.mapped(|row| row.get(0)).collect()
let params =
rusqlite::params_from_iter(iter::once(room_id.clone()).chain(memberships));
self.prepare(sql, move |mut stmt| {
stmt.query(params)?.mapped(|row| row.get(0)).collect()
})
})
.await?
};
@@ -701,19 +719,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
room_id: Key,
names: Vec<Key>,
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
let sql_params = vec!["?"; names.len()].join(", ");
let sql = format!(
"SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})"
);
let params = iter::once(room_id).chain(names);
let names_length = names.len();
Ok(self
.prepare(sql, move |mut stmt| {
stmt.query(rusqlite::params_from_iter(params))?
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
.collect()
chunk_large_query_over(names, Some(names_length), move |names| {
let sql_params = repeat_vars(names.len());
let sql = format!(
"SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})"
);
let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(names));
self.prepare(sql, move |mut stmt| {
stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
})
.await?)
})
.await
}
async fn get_user_receipt(
@@ -1531,6 +1551,51 @@ struct ReceiptData {
user_id: OwnedUserId,
}
/// Chunk a large query over some keys.
///
/// Imagine there is a _dynamic_ query that runs potentially large number of
/// parameters, so much that the maximum number of parameters can be hit. Then,
/// this helper is for you. It will execute the query on chunks of parameters.
async fn chunk_large_query_over<Query, Fut, Res>(
mut keys_to_chunk: Vec<Key>,
result_capacity: Option<usize>,
do_query: Query,
) -> Result<Vec<Res>>
where
Query: Fn(Vec<Key>) -> Fut,
Fut: Future<Output = Result<Vec<Res>, rusqlite::Error>>,
{
let mut all_results = if let Some(capacity) = result_capacity {
Vec::with_capacity(capacity)
} else {
Vec::new()
};
// `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then divide by 2 to
// let space for more static parameters (not part of `keys_to_chunk`).
let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 / 2;
let maximum_chunk_size: usize = maximum_chunk_size
.try_into()
.map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
while !keys_to_chunk.is_empty() {
let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
let chunk = keys_to_chunk;
keys_to_chunk = tail;
all_results.extend(do_query(chunk).await?);
}
Ok(all_results)
}
/// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated.
fn repeat_vars(count: usize) -> impl fmt::Display {
assert_ne!(count, 0);
iter::repeat("?").take(count).format(",")
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};