diff --git a/Cargo.lock b/Cargo.lock index 232ca8abb..92e73a0a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2745,6 +2745,7 @@ dependencies = [ "hmac", "http", "indoc", + "itertools", "matrix-sdk-common", "matrix-sdk-qrcode", "matrix-sdk-test", diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index c3a5e3283..f8e8d0d74 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -39,6 +39,7 @@ futures-core = "0.3.24" futures-util = { workspace = true } hmac = "0.12.1" http = { workspace = true, optional = true } # feature = testing only +itertools = "0.10.5" matrix-sdk-qrcode = { version = "0.4.0", path = "../matrix-sdk-qrcode", optional = true } matrix-sdk-common = { version = "0.6.0", path = "../matrix-sdk-common" } olm-rs = { version = "2.2.0", features = ["serde"], optional = true } diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index 875a5ba9c..30da8a8fb 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -20,6 +20,7 @@ use std::{ }; use futures_util::future::join_all; +use itertools::Itertools; use matrix_sdk_common::{ executor::spawn, locks::Mutex, @@ -694,7 +695,7 @@ impl IdentityManager { /// /// # Returns /// - /// A list of pairs of `(request_id, request)` + /// A map of a request ID to the `/keys/query` request. /// /// The response of a successful key query requests needs to be passed to /// the [`OlmMachine`] with the [`receive_keys_query_response`]. @@ -703,7 +704,7 @@ impl IdentityManager { /// [`receive_keys_query_response`]: #method.receive_keys_query_response pub async fn users_for_key_query( &self, - ) -> StoreResult> { + ) -> StoreResult> { // Forget about any previous key queries in flight. *self.keys_query_request_details.lock().await = None; @@ -723,25 +724,44 @@ impl IdentityManager { }; if users.is_empty() { - return Ok(Vec::new()); + Ok(BTreeMap::new()) + } else { + // Let's remove users that are part of the `FailuresCache`. The cache, which is + // a TTL cache, remembers users for which a previous `/key/query` request has + // failed. We don't retry a `/keys/query` for such users for a + // certain amount of time. + let users = users.into_iter().filter(|u| !self.failures.contains(u.server_name())); + + // We don't want to create a single `/keys/query` request with an infinite + // amount of users. Some servers will likely bail out after a + // certain amount of users and the responses will be large. In the + // case of a transmission error, we'll have to retransmit the large + // response. + // + // Convert the set of users into multiple /keys/query requests. + let requests: BTreeMap<_, _> = users + .chunks(Self::MAX_KEY_QUERY_USERS) + .into_iter() + .map(|user_chunk| { + let request_id = TransactionId::new(); + let request = KeysQueryRequest::new(user_chunk); + + debug!(?request_id, users = ?request.device_keys.keys(), "Created a /keys/query request"); + + (request_id, request) + }) + .collect(); + + // Collect the request IDs, these will be used later in the + // `receive_keys_query_response()` method to figure out if the user can be + // marked as up-to-date/non-dirty. + let request_ids = requests.keys().cloned().collect(); + let request_details = KeysQueryRequestDetails { sequence_number, request_ids }; + + *self.keys_query_request_details.lock().await = Some(request_details); + + Ok(requests) } - - let mut result: Vec<(OwnedTransactionId, KeysQueryRequest)> = Vec::new(); - let mut request_details = - KeysQueryRequestDetails { sequence_number, request_ids: HashSet::new() }; - - let filtered_users: Vec = - users.into_iter().filter(|u| !self.failures.contains(u.server_name())).collect(); - - for c in filtered_users.chunks(Self::MAX_KEY_QUERY_USERS) { - let request_id = TransactionId::new(); - let req = KeysQueryRequest::new(c.iter().map(|u| (u.clone(), Vec::new())).collect()); - debug!(?request_id, users = ?c, "Created keys query request"); - result.push((request_id.clone(), req)); - request_details.request_ids.insert(request_id); - } - *(self.keys_query_request_details.lock().await) = Some(request_details); - Ok(result) } /// Receive the list of users that contained changed devices from the @@ -1161,7 +1181,7 @@ pub(crate) mod tests { manager.update_tracked_users([alice]).await.unwrap(); // alice should be in the list of key queries - let (reqid, req) = manager.users_for_key_query().await.unwrap().pop().unwrap(); + let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); assert!(req.device_keys.contains_key(alice)); // another invalidation turns up @@ -1171,7 +1191,7 @@ pub(crate) mod tests { manager.receive_keys_query_response(&reqid, &other_key_query()).await.unwrap(); // alice should *still* be in the list of key queries - let (reqid, req) = manager.users_for_key_query().await.unwrap().pop().unwrap(); + let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); assert!(req.device_keys.contains_key(alice)); // another key query response @@ -1197,7 +1217,7 @@ pub(crate) mod tests { manager.store.tracked_users().await.unwrap().contains(alice), "Alice is tracked after being marked as tracked" ); - let (reqid, req) = manager.users_for_key_query().await.unwrap().pop().unwrap(); + let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); assert!(req.device_keys.contains_key(alice)); // a failure should stop us querying for the user's keys. diff --git a/crates/matrix-sdk-crypto/src/requests.rs b/crates/matrix-sdk-crypto/src/requests.rs index 5cec2b344..068cd5874 100644 --- a/crates/matrix-sdk-crypto/src/requests.rs +++ b/crates/matrix-sdk-crypto/src/requests.rs @@ -192,7 +192,9 @@ pub struct KeysQueryRequest { } impl KeysQueryRequest { - pub(crate) fn new(device_keys: BTreeMap>) -> Self { + pub(crate) fn new(users: impl Iterator) -> Self { + let device_keys = users.map(|u| (u, Vec::new())).collect(); + Self { timeout: None, device_keys, token: None } } }