crypto: move Store::mark_user_as_changed to the StoreCache too

This commit is contained in:
Benjamin Bouvier
2023-10-16 17:56:54 +02:00
parent 537ac8a269
commit f580966408
6 changed files with 121 additions and 63 deletions

View File

@@ -46,7 +46,7 @@ use crate::{
olm::{InboundGroupSession, Session},
requests::{OutgoingRequest, ToDeviceRequest},
session_manager::GroupSessionCache,
store::{Changes, CryptoStoreError, SecretImportError, Store},
store::{Changes, CryptoStoreError, SecretImportError, Store, StoreCache},
types::events::{
forwarded_room_key::ForwardedRoomKeyContent,
olm_v1::{DecryptedForwardedRoomKeyEvent, DecryptedSecretSendEvent},
@@ -190,7 +190,10 @@ impl GossipMachine {
/// Handle all the incoming key requests that are queued up and empty our
/// key request queue.
pub async fn collect_incoming_key_requests(&self) -> OlmResult<Vec<Session>> {
pub async fn collect_incoming_key_requests(
&self,
cache: &StoreCache,
) -> OlmResult<Vec<Session>> {
let mut changed_sessions = Vec::new();
let incoming_key_requests =
@@ -199,8 +202,8 @@ impl GossipMachine {
for event in incoming_key_requests.values() {
if let Some(s) = match event {
#[cfg(feature = "automatic-room-key-forwarding")]
RequestEvent::KeyShare(e) => Box::pin(self.handle_key_request(e)).await?,
RequestEvent::Secret(e) => Box::pin(self.handle_secret_request(e)).await?,
RequestEvent::KeyShare(e) => Box::pin(self.handle_key_request(cache, e)).await?,
RequestEvent::Secret(e) => Box::pin(self.handle_secret_request(cache, e)).await?,
#[cfg(not(feature = "automatic-room-key-forwarding"))]
_ => None,
} {
@@ -256,6 +259,7 @@ impl GossipMachine {
async fn handle_secret_request(
&self,
cache: &StoreCache,
event: &SecretRequestEvent,
) -> OlmResult<Option<Session>> {
let secret_name = match &event.content.action {
@@ -331,7 +335,7 @@ impl GossipMachine {
?secret_name,
"Received a secret request form an unknown device",
);
self.inner.store.mark_user_as_changed(&event.sender).await?;
cache.mark_user_as_changed(&self.inner.store, &event.sender).await?;
None
})
@@ -380,6 +384,7 @@ impl GossipMachine {
#[cfg(feature = "automatic-room-key-forwarding")]
async fn answer_room_key_request(
&self,
cache: &StoreCache,
event: &RoomKeyRequestEvent,
session: &InboundGroupSession,
) -> OlmResult<Option<Session>> {
@@ -390,7 +395,7 @@ impl GossipMachine {
let Some(device) = device else {
warn!("Received a key request from an unknown device");
self.inner.store.mark_user_as_changed(&event.sender).await?;
cache.mark_user_as_changed(&self.inner.store, &event.sender).await?;
return Ok(None);
};
@@ -429,6 +434,7 @@ impl GossipMachine {
)]
async fn handle_supported_key_request(
&self,
cache: &StoreCache,
event: &RoomKeyRequestEvent,
room_id: &RoomId,
session_id: &str,
@@ -436,7 +442,7 @@ impl GossipMachine {
let session = self.inner.store.get_inbound_group_session(room_id, session_id).await?;
if let Some(s) = session {
self.answer_room_key_request(event, &s).await
self.answer_room_key_request(cache, event, &s).await
} else {
debug!("Received a room key request for an unknown inbound group session",);
@@ -446,14 +452,19 @@ impl GossipMachine {
/// Handle a single incoming key request.
#[cfg(feature = "automatic-room-key-forwarding")]
async fn handle_key_request(&self, event: &RoomKeyRequestEvent) -> OlmResult<Option<Session>> {
async fn handle_key_request(
&self,
cache: &StoreCache,
event: &RoomKeyRequestEvent,
) -> OlmResult<Option<Session>> {
use crate::types::events::room_key_request::{Action, RequestedKeyInfo};
if self.inner.room_key_forwarding_enabled.load(Ordering::SeqCst) {
match &event.content.action {
Action::Request(info) => match info {
RequestedKeyInfo::MegolmV1AesSha2(i) => {
self.handle_supported_key_request(event, &i.room_id, &i.session_id).await
self.handle_supported_key_request(cache, event, &i.room_id, &i.session_id)
.await
}
#[cfg(feature = "experimental-algorithms")]
RequestedKeyInfo::MegolmV2AesSha2(i) => {
@@ -832,6 +843,7 @@ impl GossipMachine {
async fn receive_secret(
&self,
cache: &StoreCache,
sender_key: Curve25519PublicKey,
secret: GossippedSecret,
changes: &mut Changes,
@@ -850,7 +862,7 @@ impl GossipMachine {
} else {
warn!("Received a m.secret.send event from an unknown device");
self.inner.store.mark_user_as_changed(&secret.event.sender).await?;
cache.mark_user_as_changed(&self.inner.store, &secret.event.sender).await?;
}
Ok(())
@@ -859,6 +871,7 @@ impl GossipMachine {
#[instrument(skip_all, fields(sender_key, sender = ?event.sender, request_id = ?event.content.request_id, secret_name))]
pub async fn receive_secret_event(
&self,
cache: &StoreCache,
sender_key: Curve25519PublicKey,
event: &DecryptedSecretSendEvent,
changes: &mut Changes,
@@ -887,7 +900,7 @@ impl GossipMachine {
gossip_request: request,
};
self.receive_secret(sender_key, secret, changes).await?;
self.receive_secret(cache, sender_key, secret, changes).await?;
Some(secret_name)
}
@@ -1597,7 +1610,11 @@ mod tests {
// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap();
{
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Now bob does have an outgoing request.
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
@@ -1675,7 +1692,10 @@ mod tests {
// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap();
{
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Now bob does have an outgoing request.
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
@@ -1771,13 +1791,19 @@ mod tests {
// No secret found
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
alice_machine.receive_incoming_secret_request(&event);
alice_machine.collect_incoming_key_requests().await.unwrap();
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
// No device found
alice_machine.inner.store.reset_cross_signing_identity().await;
alice_machine.receive_incoming_secret_request(&event);
alice_machine.collect_incoming_key_requests().await.unwrap();
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
alice_machine.inner.store.save_devices(&[bob_device]).await.unwrap();
@@ -1785,7 +1811,10 @@ mod tests {
// The device doesn't belong to us
alice_machine.inner.store.reset_cross_signing_identity().await;
alice_machine.receive_incoming_secret_request(&event);
alice_machine.collect_incoming_key_requests().await.unwrap();
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
let event = RumaToDeviceEvent {
@@ -1799,7 +1828,10 @@ mod tests {
// The device isn't trusted
alice_machine.receive_incoming_secret_request(&event);
alice_machine.collect_incoming_key_requests().await.unwrap();
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
// We need a trusted device, otherwise we won't serve secrets
@@ -1807,13 +1839,16 @@ mod tests {
alice_machine.inner.store.save_devices(&[alice_device.clone()]).await.unwrap();
alice_machine.receive_incoming_secret_request(&event);
alice_machine.collect_incoming_key_requests().await.unwrap();
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(!alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
}
#[async_test]
#[cfg(feature = "backups_v1")]
async fn secret_broadcasting() {
async fn test_secret_broadcasting() {
use futures_util::{pin_mut, FutureExt};
use ruma::api::client::to_device::send_event_to_device::v3::Response as ToDeviceResponse;
use serde_json::value::to_raw_value;
@@ -1876,7 +1911,15 @@ mod tests {
.await
.unwrap();
alice_machine.inner.key_request_machine.receive_incoming_secret_request(&event);
alice_machine.inner.key_request_machine.collect_incoming_key_requests().await.unwrap();
{
let alice_cache = alice_machine.store().cache().await.unwrap();
alice_machine
.inner
.key_request_machine
.collect_incoming_key_requests(&alice_cache)
.await
.unwrap();
}
let requests =
alice_machine.inner.key_request_machine.outgoing_to_device_requests().await.unwrap();
@@ -1929,7 +1972,10 @@ mod tests {
// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
bob_machine.collect_incoming_key_requests().await.unwrap();
{
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Bob only has a keys claim request, since we're lacking a session
assert_eq!(bob_machine.outgoing_to_device_requests().await.unwrap().len(), 1);
assert_matches!(
@@ -1964,7 +2010,10 @@ mod tests {
bob_machine.retry_keyshare(alice_id(), alice_device_id());
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
bob_machine.collect_incoming_key_requests().await.unwrap();
{
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Bob now has an outgoing requests.
assert!(!bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.inner.wait_queue.is_empty());

View File

@@ -798,7 +798,7 @@ impl IdentityManager {
// The check for emptiness is done first for performance.
let (users, sequence_number) =
if users.is_empty() && !self.store.tracked_users().await?.contains(self.user_id()) {
self.store.mark_user_as_changed(self.user_id()).await?;
self.store.cache().await?.mark_user_as_changed(&self.store, self.user_id()).await?;
self.store.users_for_key_query().await?
} else {
(users, sequence_number)
@@ -1314,7 +1314,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn private_identity_invalidation_after_public_keys_change() {
async fn test_private_identity_invalidation_after_public_keys_change() {
let user_id = user_id!("@example1:localhost");
let manager = manager_test_helper(user_id, "DEVICEID".into()).await;
@@ -1403,7 +1403,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn no_tracked_users_key_query_request() {
async fn test_no_tracked_users_key_query_request() {
let manager = manager_test_helper(user_id(), device_id()).await;
assert!(
@@ -1458,7 +1458,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn failure_handling() {
async fn test_failure_handling() {
let manager = manager_test_helper(user_id(), device_id()).await;
let alice = user_id!("@alice:example.org");
@@ -1466,7 +1466,11 @@ pub(crate) mod tests {
manager.store.tracked_users().await.unwrap().is_empty(),
"No users are initially tracked"
);
manager.store.mark_user_as_changed(alice).await.unwrap();
{
let cache = manager.store.cache().await.unwrap();
cache.mark_user_as_changed(&manager.store, alice).await.unwrap();
}
assert!(
manager.store.tracked_users().await.unwrap().contains(alice),
@@ -1517,7 +1521,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn devices_stream() {
async fn test_devices_stream() {
let manager = manager_test_helper(user_id(), device_id()).await;
let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]);
@@ -1531,7 +1535,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn identities_stream() {
async fn test_identities_stream() {
let manager = manager_test_helper(user_id(), device_id()).await;
let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]);
@@ -1545,7 +1549,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn identities_stream_raw() {
async fn test_identities_stream_raw() {
let mut manager = Some(manager_test_helper(user_id(), device_id()).await);
let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![user_id()]);
@@ -1590,7 +1594,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn identities_stream_raw_signature_update() {
async fn test_identities_stream_raw_signature_update() {
let mut manager = Some(manager_test_helper(user_id(), device_id()).await);
let (request_id, _) =
manager.as_ref().unwrap().build_key_query_for_users(vec![other_user_id()]);

View File

@@ -70,7 +70,7 @@ use crate::{
session_manager::{GroupSessionManager, SessionManager},
store::{
Changes, CryptoStoreWrapper, DeviceChanges, IdentityChanges, IntoCryptoStore, MemoryStore,
PendingChanges, Result as StoreResult, RoomKeyInfo, SecretImportError, Store,
PendingChanges, Result as StoreResult, RoomKeyInfo, SecretImportError, Store, StoreCache,
StoreTransaction,
},
types::{
@@ -656,7 +656,7 @@ impl OlmMachine {
// Handle the decrypted event, e.g. fetch out Megolm sessions out of
// the event.
self.handle_decrypted_to_device_event(&mut decrypted, changes).await?;
self.handle_decrypted_to_device_event(transaction.cache(), &mut decrypted, changes).await?;
Ok(decrypted)
}
@@ -896,6 +896,7 @@ impl OlmMachine {
)]
async fn handle_decrypted_to_device_event(
&self,
cache: &StoreCache,
decrypted: &mut OlmDecryptionInfo,
changes: &mut Changes,
) -> OlmResult<()> {
@@ -918,7 +919,7 @@ impl OlmMachine {
let name = self
.inner
.key_request_machine
.receive_secret_event(decrypted.result.sender_key, e, changes)
.receive_secret_event(cache, decrypted.result.sender_key, e, changes)
.await?;
// Set the secret name so other consumers of the event know
@@ -1186,8 +1187,11 @@ impl OlmMachine {
events.push(raw_event);
}
let changed_sessions =
self.inner.key_request_machine.collect_incoming_key_requests().await?;
let changed_sessions = self
.inner
.key_request_machine
.collect_incoming_key_requests(transaction.cache())
.await?;
changes.sessions.extend(changed_sessions);
changes.next_batch_token = sync_changes.next_batch_token;
@@ -2208,7 +2212,7 @@ pub(crate) mod tests {
(machine, keys)
}
async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) {
async fn get_machine_after_query_test_helper() -> (OlmMachine, OneTimeKeys) {
let (machine, otk) = get_prepared_machine_test_helper(user_id(), false).await;
let response = keys_query_response();
let req_id = TransactionId::new();
@@ -2537,7 +2541,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_missing_sessions_calculation() {
let (machine, _) = get_machine_after_query().await;
let (machine, _) = get_machine_after_query_test_helper().await;
let alice = alice_id();
let alice_device = alice_device_id();
@@ -2815,7 +2819,7 @@ pub(crate) mod tests {
async fn test_request_missing_secrets_cross_signed() {
let (alice, bob) = get_machine_pair_with_session(alice_id(), bob_id(), false).await;
setup_cross_signing_for_machine(&alice, &bob).await;
setup_cross_signing_for_machine_test_helper(&alice, &bob).await;
let should_query_secrets = alice.query_missing_secrets_from_other_sessions().await.unwrap();
@@ -3092,7 +3096,7 @@ pub(crate) mod tests {
);
assert_shield!(encryption_info, Red, Red);
setup_cross_signing_for_machine(&alice, &bob).await;
setup_cross_signing_for_machine_test_helper(&alice, &bob).await;
let bob_id_from_alice = alice.get_identity(bob.user_id(), None).await.unwrap();
assert_matches!(bob_id_from_alice, Some(UserIdentities::Other(_)));
let alice_id_from_bob = bob.get_identity(alice.user_id(), None).await.unwrap();
@@ -3173,7 +3177,7 @@ pub(crate) mod tests {
);
}
async fn setup_cross_signing_for_machine(alice: &OlmMachine, bob: &OlmMachine) {
async fn setup_cross_signing_for_machine_test_helper(alice: &OlmMachine, bob: &OlmMachine) {
let (alice_upload_signing, _) =
alice.bootstrap_cross_signing(false).await.expect("Expect Alice x-signing key request");
@@ -3517,7 +3521,7 @@ pub(crate) mod tests {
}
#[async_test]
async fn interactive_verification() {
async fn test_interactive_verification() {
let (alice, bob) =
get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
@@ -3913,7 +3917,7 @@ pub(crate) mod tests {
let (alice, bob) =
get_machine_pair_with_setup_sessions_test_helper(alice_id(), user_id(), false).await;
setup_cross_signing_for_machine(&alice, &bob).await;
setup_cross_signing_for_machine_test_helper(&alice, &bob).await;
let second_alice = create_additional_machine(&alice).await;

View File

@@ -969,7 +969,7 @@ mod tests {
.expect("Can't parse the keys claim response")
}
async fn machine_with_user(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
let keys_query = keys_query_response();
let keys_claim = keys_claim_response();
let txn_id = TransactionId::new();
@@ -985,10 +985,10 @@ mod tests {
}
async fn machine() -> OlmMachine {
machine_with_user(alice_id(), alice_device_id()).await
machine_with_user_test_helper(alice_id(), alice_device_id()).await
}
async fn machine_with_shared_room_key() -> OlmMachine {
async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
let machine = machine().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
@@ -1119,7 +1119,7 @@ mod tests {
#[async_test]
async fn ratcheted_sharing() {
let machine = machine_with_shared_room_key().await;
let machine = machine_with_shared_room_key_test_helper().await;
let room_id = room_id!("!test:localhost");
let late_joiner = user_id!("@bob:localhost");
@@ -1147,7 +1147,7 @@ mod tests {
#[async_test]
async fn changing_encryption_settings() {
let machine = machine_with_shared_room_key().await;
let machine = machine_with_shared_room_key_test_helper().await;
let room_id = room_id!("!test:localhost");
let keys_claim = keys_claim_response();
@@ -1201,7 +1201,7 @@ mod tests {
let device_id = device_id!("TESTDEVICE");
let room_id = room_id!("!test:localhost");
let machine = machine_with_user(user_id, device_id).await;
let machine = machine_with_user_test_helper(user_id, device_id).await;
let (outbound, _) = machine
.inner
@@ -1343,7 +1343,7 @@ mod tests {
}
#[async_test]
async fn no_olm_withheld_only_sent_once() {
async fn test_no_olm_withheld_only_sent_once() {
let keys_query = keys_query_response();
let txn_id = TransactionId::new();

View File

@@ -438,7 +438,8 @@ impl SessionManager {
}
}
match self.key_request_machine.collect_incoming_key_requests().await {
let store_cache = self.store.cache().await?;
match self.key_request_machine.collect_incoming_key_requests(&store_cache).await {
Ok(sessions) => {
let changes = Changes { sessions, ..Default::default() };
self.store.save_changes(changes).await?

View File

@@ -206,6 +206,18 @@ impl StoreCache {
store.inner.store.save_tracked_users(&store_updates).await
}
/// Mark the given user as being tracked for device lists, and mark that it
/// has an outdated device list.
///
/// This means that the user will be considered for a `/keys/query` request
/// next time [`Store::users_for_key_query()`] is called.
pub(crate) async fn mark_user_as_changed(&self, store: &Store, user: &UserId) -> Result<()> {
store.inner.users_for_key_query.lock().await.insert_user(user);
self.tracked_users.write().unwrap().insert(user.to_owned());
store.inner.store.save_tracked_users(&[(user, true)]).await
}
}
pub(crate) struct StoreCacheGuard {
@@ -1086,18 +1098,6 @@ impl Store {
Ok(())
}
/// Mark the given user as being tracked for device lists, and mark that it
/// has an outdated device list.
///
/// This means that the user will be considered for a `/keys/query` request
/// next time [`Store::users_for_key_query()`] is called.
pub(crate) async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
self.inner.users_for_key_query.lock().await.insert_user(user);
self.cache().await?.tracked_users.write().unwrap().insert(user.to_owned());
self.inner.store.save_tracked_users(&[(user, true)]).await
}
/// Add entries to the list of users being tracked for device changes
///
/// Any users not already on the list are flagged as awaiting a key query.