feat(crypto): Handle the failures field in the /keys/query response

This commit is contained in:
Damir Jelić
2023-01-03 12:32:46 +01:00
parent 1b18402707
commit 8db84a3c9e

View File

@@ -26,7 +26,7 @@ use matrix_sdk_common::{
};
use ruma::{
api::client::keys::get_keys::v3::Response as KeysQueryResponse, serde::Raw, DeviceId,
OwnedDeviceId, OwnedUserId, UserId,
OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId,
};
use tracing::{debug, info, trace, warn};
@@ -40,6 +40,7 @@ use crate::{
requests::KeysQueryRequest,
store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store},
types::{CrossSigningKey, DeviceKeys},
utilities::FailuresCache,
LocalTrust,
};
@@ -117,6 +118,7 @@ pub(crate) struct IdentityManager {
user_id: Arc<UserId>,
device_id: Arc<DeviceId>,
keys_query_listener: KeysQueryListener,
failures: FailuresCache<OwnedServerName>,
store: Store,
}
@@ -126,7 +128,13 @@ impl IdentityManager {
pub fn new(user_id: Arc<UserId>, device_id: Arc<DeviceId>, store: Store) -> Self {
let keys_query_listener = KeysQueryListener::new(store.clone());
IdentityManager { user_id, device_id, store, keys_query_listener }
IdentityManager {
user_id,
device_id,
store,
keys_query_listener,
failures: Default::default(),
}
}
fn user_id(&self) -> &UserId {
@@ -156,6 +164,23 @@ impl IdentityManager {
"Handling a keys query response"
);
// Parse the strings into server names and filter out our own server. We should
// never get failures from our own server but let's remove it as a
// precaution anyways.
let failed_servers = response
.failures
.keys()
.filter_map(|k| ServerName::parse(k).ok())
.filter(|s| s != self.user_id().server_name());
let successful_servers = response.device_keys.keys().map(|u| u.server_name());
// Append the new failed servers and remove any successful servers. We remove
// the successful servers explicitly so we keep incrementing the delay
// on the failed servers instead of removing them when the entry times
// out.
self.failures.extend(failed_servers);
self.failures.remove(successful_servers);
let devices = self.handle_devices_from_key_query(response.device_keys.clone()).await?;
let (identities, cross_signing_identity) = self.handle_cross_singing_keys(response).await?;
@@ -166,7 +191,7 @@ impl IdentityManager {
..Default::default()
};
// TODO turn this into a single transaction.
// TODO: turn this into a single transaction.
self.store.save_changes(changes).await?;
let updated_users: Vec<&UserId> = response.device_keys.keys().map(Deref::deref).collect();
@@ -408,7 +433,7 @@ impl IdentityManager {
let mut changes = IdentityChanges::default();
let mut changed_identity = None;
// TODO this is a bit chunky, refactor this into smaller methods.
// TODO: this is a bit chunky, refactor this into smaller methods.
for (user_id, master_key) in &response.master_keys {
match master_key.deserialize_as::<CrossSigningKey>() {
@@ -570,7 +595,8 @@ impl IdentityManager {
if users.is_empty() {
Vec::new()
} else {
let users: Vec<OwnedUserId> = users.into_iter().collect();
let users: Vec<OwnedUserId> =
users.into_iter().filter(|u| !self.failures.contains(u.server_name())).collect();
users
.chunks(Self::MAX_KEY_QUERY_USERS)
@@ -830,11 +856,44 @@ pub(crate) mod testing {
pub(crate) mod tests {
use std::time::Duration;
use matrix_sdk_test::async_test;
use ruma::device_id;
use matrix_sdk_test::{async_test, response_from_file};
use ruma::{
api::{client::keys::get_keys::v3::Response as KeysQueryResponse, IncomingResponse},
device_id, user_id,
};
use serde_json::json;
use super::testing::{manager, other_key_query, other_user_id};
fn key_query_without_failures() -> KeysQueryResponse {
let response = json!({
"device_keys": {
"@alice:example.org": {
},
}
});
let response = response_from_file(&response);
KeysQueryResponse::try_from_http_response(response).unwrap()
}
fn key_query_with_failures() -> KeysQueryResponse {
let response = json!({
"device_keys": {
},
"failures": {
"example.org": {
"errcode": "M_RESOURCE_LIMIT_EXCEEDED",
"error": "Not yet ready to retry",
}
}
});
let response = response_from_file(&response);
KeysQueryResponse::try_from_http_response(response).unwrap()
}
#[async_test]
async fn test_manager_creation() {
let manager = manager();
@@ -909,4 +968,49 @@ pub(crate) mod tests {
"Our own user is now tracked"
);
}
#[async_test]
async fn failure_handling() {
let manager = manager();
let alice = user_id!("@alice:example.org");
assert!(manager.store.tracked_users().is_empty(), "No users are initially tracked");
manager.store.update_tracked_user(alice, true).await.unwrap();
assert!(
manager.store.tracked_users().contains(alice),
"Alice is tracked after being marked as tracked"
);
assert!(manager
.users_for_key_query()
.await
.iter()
.any(|r| r.device_keys.contains_key(alice)));
let response = key_query_with_failures();
manager.receive_keys_query_response(&response).await.unwrap();
assert!(manager.failures.contains(alice.server_name()));
assert!(!manager
.users_for_key_query()
.await
.iter()
.any(|r| r.device_keys.contains_key(alice)));
let response = key_query_without_failures();
manager.receive_keys_query_response(&response).await.unwrap();
assert!(!manager.failures.contains(alice.server_name()));
assert!(!manager
.users_for_key_query()
.await
.iter()
.any(|r| r.device_keys.contains_key(alice)));
manager.store.update_tracked_user(alice, true).await.unwrap();
assert!(manager
.users_for_key_query()
.await
.iter()
.any(|r| r.device_keys.contains_key(alice)));
}
}