mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-15 03:25:46 -04:00
fix(sqlite): Prevent reaching the maximum number of parameters in an SQL statement
fix(sqlite): Prevent reaching the maximum number of parameters in an SQL statement
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2870,6 +2870,7 @@ dependencies = [
|
||||
"ctor",
|
||||
"deadpool-sqlite",
|
||||
"glob",
|
||||
"itertools 0.11.0",
|
||||
"matrix-sdk-base",
|
||||
"matrix-sdk-crypto",
|
||||
"matrix-sdk-store-encryption",
|
||||
|
||||
@@ -21,12 +21,13 @@ state-store = []
|
||||
[dependencies]
|
||||
async-trait = { workspace = true }
|
||||
deadpool-sqlite = "0.5.0"
|
||||
itertools = { workspace = true }
|
||||
matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base" }
|
||||
matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true }
|
||||
matrix-sdk-store-encryption = { version = "0.2.0", path = "../matrix-sdk-store-encryption" }
|
||||
rmp-serde = "1.1.1"
|
||||
ruma = { workspace = true }
|
||||
rusqlite = "0.28.0"
|
||||
rusqlite = { version = "0.28.0", features = ["limits"] }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
@@ -69,22 +69,34 @@ pub enum OpenStoreError {
|
||||
pub enum Error {
|
||||
#[error(transparent)]
|
||||
Sqlite(rusqlite::Error),
|
||||
|
||||
#[error("Failed to compute the maximum variable number from {0}")]
|
||||
SqliteMaximumVariableNumber(i32),
|
||||
|
||||
#[error(transparent)]
|
||||
Pool(PoolError),
|
||||
|
||||
#[error(transparent)]
|
||||
Encode(rmp_serde::encode::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Decode(rmp_serde::decode::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
Encryption(matrix_sdk_store_encryption::Error),
|
||||
|
||||
#[error("can't save/load sessions or group sessions in the store before an account is stored")]
|
||||
AccountUnset,
|
||||
|
||||
#[error(transparent)]
|
||||
Pickle(#[from] vodozemac::PickleError),
|
||||
|
||||
#[error("An object failed to be decrypted while unpickling")]
|
||||
Unpickle,
|
||||
|
||||
#[error("Redaction failed: {0}")]
|
||||
Redaction(#[source] ruma::canonical_json::RedactionError),
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::min,
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
fmt, iter,
|
||||
fmt,
|
||||
future::Future,
|
||||
iter,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use deadpool_sqlite::{Object as SqliteConn, Pool as SqlitePool, Runtime};
|
||||
use itertools::Itertools;
|
||||
use matrix_sdk_base::{
|
||||
deserialized_responses::RawAnySyncOrStrippedState,
|
||||
media::{MediaRequest, UniqueKey},
|
||||
@@ -27,7 +31,7 @@ use ruma::{
|
||||
serde::Raw,
|
||||
CanonicalJsonObject, EventId, OwnedEventId, OwnedUserId, RoomId, RoomVersionId, UserId,
|
||||
};
|
||||
use rusqlite::{OptionalExtension, Transaction};
|
||||
use rusqlite::{limits::Limit, OptionalExtension, Transaction};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use tokio::fs;
|
||||
use tracing::{debug, warn};
|
||||
@@ -550,14 +554,19 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
|
||||
}
|
||||
|
||||
async fn get_kv_blobs(&self, keys: Vec<Key>) -> Result<Vec<Vec<u8>>> {
|
||||
let sql_params = vec!["?"; keys.len()].join(", ");
|
||||
let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})");
|
||||
let keys_length = keys.len();
|
||||
|
||||
Ok(self
|
||||
.prepare(sql, move |mut stmt| {
|
||||
stmt.query(rusqlite::params_from_iter(keys))?.mapped(|row| row.get(0)).collect()
|
||||
chunk_large_query_over(keys, Some(keys_length), |keys| {
|
||||
let sql_params = repeat_vars(keys.len());
|
||||
let sql = format!("SELECT value FROM kv_blob WHERE key IN ({sql_params})");
|
||||
|
||||
let params = rusqlite::params_from_iter(keys);
|
||||
|
||||
self.prepare(sql, move |mut stmt| {
|
||||
stmt.query(params)?.mapped(|row| row.get(0)).collect()
|
||||
})
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn set_kv_blob(&self, key: Key, value: Vec<u8>) -> Result<()>;
|
||||
@@ -575,16 +584,17 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
|
||||
})
|
||||
.await?)
|
||||
} else {
|
||||
let sql_params = vec!["?"; states.len()].join(", ");
|
||||
let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})");
|
||||
chunk_large_query_over(states, None, |states| {
|
||||
let sql_params = repeat_vars(states.len());
|
||||
let sql = format!("SELECT data FROM room_info WHERE state IN ({sql_params})");
|
||||
|
||||
Ok(self
|
||||
.prepare(sql, move |mut stmt| {
|
||||
self.prepare(sql, move |mut stmt| {
|
||||
stmt.query(rusqlite::params_from_iter(states))?
|
||||
.mapped(|row| row.get(0))
|
||||
.collect()
|
||||
})
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,20 +604,22 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
|
||||
event_type: Key,
|
||||
state_keys: Vec<Key>,
|
||||
) -> Result<Vec<(bool, Vec<u8>)>> {
|
||||
let sql_params = vec!["?"; state_keys.len()].join(", ");
|
||||
let sql = format!(
|
||||
"SELECT stripped, data FROM state_event
|
||||
chunk_large_query_over(state_keys, None, move |state_keys: Vec<Key>| {
|
||||
let sql_params = repeat_vars(state_keys.len());
|
||||
let sql = format!(
|
||||
"SELECT stripped, data FROM state_event
|
||||
WHERE room_id = ? AND event_type = ? AND state_key IN ({sql_params})"
|
||||
);
|
||||
let params = [room_id, event_type].into_iter().chain(state_keys);
|
||||
);
|
||||
|
||||
Ok(self
|
||||
.prepare(sql, move |mut stmt| {
|
||||
stmt.query(rusqlite::params_from_iter(params))?
|
||||
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.collect()
|
||||
let params = rusqlite::params_from_iter(
|
||||
[room_id.clone(), event_type.clone()].into_iter().chain(state_keys),
|
||||
);
|
||||
|
||||
self.prepare(sql, |mut stmt| {
|
||||
stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
|
||||
})
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_maybe_stripped_state_events(
|
||||
@@ -633,19 +645,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
|
||||
room_id: Key,
|
||||
user_ids: Vec<Key>,
|
||||
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
|
||||
let sql_params = vec!["?"; user_ids.len()].join(", ");
|
||||
let sql = format!(
|
||||
"SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})"
|
||||
);
|
||||
let params = iter::once(room_id).chain(user_ids);
|
||||
let user_ids_length = user_ids.len();
|
||||
|
||||
Ok(self
|
||||
.prepare(sql, move |mut stmt| {
|
||||
stmt.query(rusqlite::params_from_iter(params))?
|
||||
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.collect()
|
||||
chunk_large_query_over(user_ids, Some(user_ids_length), move |user_ids| {
|
||||
let sql_params = repeat_vars(user_ids.len());
|
||||
let sql = format!(
|
||||
"SELECT user_id, data FROM profile WHERE room_id = ? AND user_id IN ({sql_params})"
|
||||
);
|
||||
|
||||
let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(user_ids));
|
||||
|
||||
self.prepare(sql, move |mut stmt| {
|
||||
stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
|
||||
})
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_user_ids(&self, room_id: Key, memberships: Vec<Key>) -> Result<Vec<Vec<u8>>> {
|
||||
@@ -655,14 +669,18 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
|
||||
})
|
||||
.await?
|
||||
} else {
|
||||
let sql_params = vec!["?"; memberships.len()].join(", ");
|
||||
let sql = format!(
|
||||
"SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})"
|
||||
);
|
||||
let params = iter::once(room_id).chain(memberships);
|
||||
chunk_large_query_over(memberships, None, move |memberships| {
|
||||
let sql_params = repeat_vars(memberships.len());
|
||||
let sql = format!(
|
||||
"SELECT data FROM member WHERE room_id = ? AND membership IN ({sql_params})"
|
||||
);
|
||||
|
||||
self.prepare(sql, move |mut stmt| {
|
||||
stmt.query(rusqlite::params_from_iter(params))?.mapped(|row| row.get(0)).collect()
|
||||
let params =
|
||||
rusqlite::params_from_iter(iter::once(room_id.clone()).chain(memberships));
|
||||
|
||||
self.prepare(sql, move |mut stmt| {
|
||||
stmt.query(params)?.mapped(|row| row.get(0)).collect()
|
||||
})
|
||||
})
|
||||
.await?
|
||||
};
|
||||
@@ -701,19 +719,21 @@ trait SqliteObjectStateStoreExt: SqliteObjectExt {
|
||||
room_id: Key,
|
||||
names: Vec<Key>,
|
||||
) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
|
||||
let sql_params = vec!["?"; names.len()].join(", ");
|
||||
let sql = format!(
|
||||
"SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})"
|
||||
);
|
||||
let params = iter::once(room_id).chain(names);
|
||||
let names_length = names.len();
|
||||
|
||||
Ok(self
|
||||
.prepare(sql, move |mut stmt| {
|
||||
stmt.query(rusqlite::params_from_iter(params))?
|
||||
.mapped(|row| Ok((row.get(0)?, row.get(1)?)))
|
||||
.collect()
|
||||
chunk_large_query_over(names, Some(names_length), move |names| {
|
||||
let sql_params = repeat_vars(names.len());
|
||||
let sql = format!(
|
||||
"SELECT name, data FROM display_name WHERE room_id = ? AND name IN ({sql_params})"
|
||||
);
|
||||
|
||||
let params = rusqlite::params_from_iter(iter::once(room_id.clone()).chain(names));
|
||||
|
||||
self.prepare(sql, move |mut stmt| {
|
||||
stmt.query(params)?.mapped(|row| Ok((row.get(0)?, row.get(1)?))).collect()
|
||||
})
|
||||
.await?)
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_user_receipt(
|
||||
@@ -1531,6 +1551,51 @@ struct ReceiptData {
|
||||
user_id: OwnedUserId,
|
||||
}
|
||||
|
||||
/// Chunk a large query over some keys.
|
||||
///
|
||||
/// Imagine there is a _dynamic_ query that runs potentially large number of
|
||||
/// parameters, so much that the maximum number of parameters can be hit. Then,
|
||||
/// this helper is for you. It will execute the query on chunks of parameters.
|
||||
async fn chunk_large_query_over<Query, Fut, Res>(
|
||||
mut keys_to_chunk: Vec<Key>,
|
||||
result_capacity: Option<usize>,
|
||||
do_query: Query,
|
||||
) -> Result<Vec<Res>>
|
||||
where
|
||||
Query: Fn(Vec<Key>) -> Fut,
|
||||
Fut: Future<Output = Result<Vec<Res>, rusqlite::Error>>,
|
||||
{
|
||||
let mut all_results = if let Some(capacity) = result_capacity {
|
||||
Vec::with_capacity(capacity)
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// `Limit` has a `repr(i32)`, it's safe to cast it to `i32`. Then divide by 2 to
|
||||
// let space for more static parameters (not part of `keys_to_chunk`).
|
||||
let maximum_chunk_size = Limit::SQLITE_LIMIT_VARIABLE_NUMBER as i32 / 2;
|
||||
let maximum_chunk_size: usize = maximum_chunk_size
|
||||
.try_into()
|
||||
.map_err(|_| Error::SqliteMaximumVariableNumber(maximum_chunk_size))?;
|
||||
|
||||
while !keys_to_chunk.is_empty() {
|
||||
let tail = keys_to_chunk.split_off(min(keys_to_chunk.len(), maximum_chunk_size));
|
||||
let chunk = keys_to_chunk;
|
||||
keys_to_chunk = tail;
|
||||
|
||||
all_results.extend(do_query(chunk).await?);
|
||||
}
|
||||
|
||||
Ok(all_results)
|
||||
}
|
||||
|
||||
/// Repeat `?` n times, where n is defined by `count`. `?` are comma-separated.
|
||||
fn repeat_vars(count: usize) -> impl fmt::Display {
|
||||
assert_ne!(count, 0);
|
||||
|
||||
iter::repeat("?").take(count).format(",")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
|
||||
|
||||
Reference in New Issue
Block a user