diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index 3a9a6e6db..f7ba59e33 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -26,7 +26,7 @@ use matrix_sdk_common::{ }; use ruma::{ api::client::keys::get_keys::v3::Response as KeysQueryResponse, serde::Raw, DeviceId, - OwnedDeviceId, OwnedUserId, UserId, + OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId, }; use tracing::{debug, info, trace, warn}; @@ -40,6 +40,7 @@ use crate::{ requests::KeysQueryRequest, store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, types::{CrossSigningKey, DeviceKeys}, + utilities::FailuresCache, LocalTrust, }; @@ -117,6 +118,7 @@ pub(crate) struct IdentityManager { user_id: Arc, device_id: Arc, keys_query_listener: KeysQueryListener, + failures: FailuresCache, store: Store, } @@ -126,7 +128,13 @@ impl IdentityManager { pub fn new(user_id: Arc, device_id: Arc, store: Store) -> Self { let keys_query_listener = KeysQueryListener::new(store.clone()); - IdentityManager { user_id, device_id, store, keys_query_listener } + IdentityManager { + user_id, + device_id, + store, + keys_query_listener, + failures: Default::default(), + } } fn user_id(&self) -> &UserId { @@ -156,6 +164,23 @@ impl IdentityManager { "Handling a keys query response" ); + // Parse the strings into server names and filter out our own server. We should + // never get failures from our own server but let's remove it as a + // precaution anyways. + let failed_servers = response + .failures + .keys() + .filter_map(|k| ServerName::parse(k).ok()) + .filter(|s| s != self.user_id().server_name()); + let successful_servers = response.device_keys.keys().map(|u| u.server_name()); + + // Append the new failed servers and remove any successful servers. We remove + // the successful servers explicitly so we keep incrementing the delay + // on the failed servers instead of removing them when the entry times + // out. + self.failures.extend(failed_servers); + self.failures.remove(successful_servers); + let devices = self.handle_devices_from_key_query(response.device_keys.clone()).await?; let (identities, cross_signing_identity) = self.handle_cross_singing_keys(response).await?; @@ -166,7 +191,7 @@ impl IdentityManager { ..Default::default() }; - // TODO turn this into a single transaction. + // TODO: turn this into a single transaction. self.store.save_changes(changes).await?; let updated_users: Vec<&UserId> = response.device_keys.keys().map(Deref::deref).collect(); @@ -408,7 +433,7 @@ impl IdentityManager { let mut changes = IdentityChanges::default(); let mut changed_identity = None; - // TODO this is a bit chunky, refactor this into smaller methods. + // TODO: this is a bit chunky, refactor this into smaller methods. for (user_id, master_key) in &response.master_keys { match master_key.deserialize_as::() { @@ -570,7 +595,8 @@ impl IdentityManager { if users.is_empty() { Vec::new() } else { - let users: Vec = users.into_iter().collect(); + let users: Vec = + users.into_iter().filter(|u| !self.failures.contains(u.server_name())).collect(); users .chunks(Self::MAX_KEY_QUERY_USERS) @@ -830,11 +856,44 @@ pub(crate) mod testing { pub(crate) mod tests { use std::time::Duration; - use matrix_sdk_test::async_test; - use ruma::device_id; + use matrix_sdk_test::{async_test, response_from_file}; + use ruma::{ + api::{client::keys::get_keys::v3::Response as KeysQueryResponse, IncomingResponse}, + device_id, user_id, + }; + use serde_json::json; use super::testing::{manager, other_key_query, other_user_id}; + fn key_query_without_failures() -> KeysQueryResponse { + let response = json!({ + "device_keys": { + "@alice:example.org": { + }, + } + }); + + let response = response_from_file(&response); + + KeysQueryResponse::try_from_http_response(response).unwrap() + } + fn key_query_with_failures() -> KeysQueryResponse { + let response = json!({ + "device_keys": { + }, + "failures": { + "example.org": { + "errcode": "M_RESOURCE_LIMIT_EXCEEDED", + "error": "Not yet ready to retry", + } + } + }); + + let response = response_from_file(&response); + + KeysQueryResponse::try_from_http_response(response).unwrap() + } + #[async_test] async fn test_manager_creation() { let manager = manager(); @@ -909,4 +968,49 @@ pub(crate) mod tests { "Our own user is now tracked" ); } + + #[async_test] + async fn failure_handling() { + let manager = manager(); + let alice = user_id!("@alice:example.org"); + + assert!(manager.store.tracked_users().is_empty(), "No users are initially tracked"); + manager.store.update_tracked_user(alice, true).await.unwrap(); + + assert!( + manager.store.tracked_users().contains(alice), + "Alice is tracked after being marked as tracked" + ); + assert!(manager + .users_for_key_query() + .await + .iter() + .any(|r| r.device_keys.contains_key(alice))); + + let response = key_query_with_failures(); + + manager.receive_keys_query_response(&response).await.unwrap(); + assert!(manager.failures.contains(alice.server_name())); + assert!(!manager + .users_for_key_query() + .await + .iter() + .any(|r| r.device_keys.contains_key(alice))); + + let response = key_query_without_failures(); + manager.receive_keys_query_response(&response).await.unwrap(); + assert!(!manager.failures.contains(alice.server_name())); + assert!(!manager + .users_for_key_query() + .await + .iter() + .any(|r| r.device_keys.contains_key(alice))); + + manager.store.update_tracked_user(alice, true).await.unwrap(); + assert!(manager + .users_for_key_query() + .await + .iter() + .any(|r| r.device_keys.contains_key(alice))); + } }