diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index eeee666fe..2fb6f2182 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -94,6 +94,10 @@ use crate::{ /// Matrix end to end encryption. #[derive(Clone)] pub struct OlmMachine { + pub(crate) inner: Arc, +} + +pub struct OlmMachineInner { /// The unique user id that owns this account. user_id: OwnedUserId, /// The unique device ID of the device that holds this account. @@ -131,8 +135,8 @@ pub struct OlmMachine { impl std::fmt::Debug for OlmMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OlmMachine") - .field("user_id", &self.user_id) - .field("device_id", &self.device_id) + .field("user_id", &self.user_id()) + .field("device_id", &self.device_id()) .finish() } } @@ -195,7 +199,7 @@ impl OlmMachine { #[cfg(feature = "backups_v1")] let backup_machine = BackupMachine::new(account.clone(), store.clone(), None); - OlmMachine { + let inner = Arc::new(OlmMachineInner { user_id, device_id, account, @@ -208,7 +212,9 @@ impl OlmMachine { identity_manager, #[cfg(feature = "backups_v1")] backup_machine, - } + }); + + Self { inner } } /// Create a new OlmMachine with the given [`CryptoStore`]. @@ -299,27 +305,27 @@ impl OlmMachine { /// Get the crypto store associated with this `OlmMachine` instance. pub fn store(&self) -> &Store { - &self.store + &self.inner.store } /// The unique user id that owns this `OlmMachine` instance. pub fn user_id(&self) -> &UserId { - &self.user_id + &self.inner.user_id } /// The unique device ID that identifies this `OlmMachine`. pub fn device_id(&self) -> &DeviceId { - &self.device_id + &self.inner.device_id } /// Get the public parts of our Olm identity keys. pub fn identity_keys(&self) -> IdentityKeys { - self.account.identity_keys() + self.inner.account.identity_keys() } /// Get the display name of our own device pub async fn display_name(&self) -> StoreResult> { - self.store.device_display_name().await + self.store().device_display_name().await } /// Get the list of "tracked users". @@ -327,7 +333,7 @@ impl OlmMachine { /// See [`update_tracked_users`](#method.update_tracked_users) for more /// information. pub async fn tracked_users(&self) -> StoreResult> { - self.store.tracked_users().await + self.store().tracked_users().await } /// Enable or disable room key forwarding. @@ -337,12 +343,12 @@ impl OlmMachine { /// events. #[cfg(feature = "automatic-room-key-forwarding")] pub fn toggle_room_key_forwarding(&self, enable: bool) { - self.key_request_machine.toggle_room_key_forwarding(enable) + self.inner.key_request_machine.toggle_room_key_forwarding(enable) } /// Is room key forwarding enabled? pub fn is_room_key_forwarding_enabled(&self) -> bool { - self.key_request_machine.is_room_key_forwarding_enabled() + self.inner.key_request_machine.is_room_key_forwarding_enabled() } /// Get the outgoing requests that need to be sent out. @@ -359,11 +365,12 @@ impl OlmMachine { request_id: TransactionId::new(), request: Arc::new(r.into()), }) { - self.account.save().await?; + self.inner.account.save().await?; requests.push(r); } for request in self + .inner .identity_manager .users_for_key_query() .await? @@ -373,8 +380,8 @@ impl OlmMachine { requests.push(request); } - requests.append(&mut self.verification_machine.outgoing_messages()); - requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?); + requests.append(&mut self.inner.verification_machine.outgoing_messages()); + requests.append(&mut self.inner.key_request_machine.outgoing_to_device_requests().await?); Ok(requests) } @@ -410,14 +417,14 @@ impl OlmMachine { self.receive_cross_signing_upload_response().await?; } IncomingResponse::SignatureUpload(_) => { - self.verification_machine.mark_request_as_sent(request_id); + self.inner.verification_machine.mark_request_as_sent(request_id); } IncomingResponse::RoomMessage(_) => { - self.verification_machine.mark_request_as_sent(request_id); + self.inner.verification_machine.mark_request_as_sent(request_id); } IncomingResponse::KeysBackup(_) => { #[cfg(feature = "backups_v1")] - self.backup_machine.mark_request_as_sent(request_id).await?; + self.inner.backup_machine.mark_request_as_sent(request_id).await?; } }; @@ -426,12 +433,12 @@ impl OlmMachine { /// Mark the cross signing identity as shared. async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> { - let identity = self.user_identity.lock().await; + let identity = self.inner.user_identity.lock().await; identity.mark_as_shared(); let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() }; - self.store.save_changes(changes).await + self.store().save_changes(changes).await } /// Create a new cross signing identity and get the upload request to push @@ -454,11 +461,12 @@ impl OlmMachine { &self, reset: bool, ) -> StoreResult<(UploadSigningKeysRequest, UploadSignaturesRequest)> { - let mut identity = self.user_identity.lock().await; + let mut identity = self.inner.user_identity.lock().await; if identity.is_empty().await || reset { info!("Creating new cross signing identity"); - let (id, request, signature_request) = self.account.bootstrap_cross_signing().await; + let (id, request, signature_request) = + self.inner.account.bootstrap_cross_signing().await; *identity = id; @@ -472,7 +480,7 @@ impl OlmMachine { ..Default::default() }; - self.store.save_changes(changes).await?; + self.store().save_changes(changes).await?; Ok((request, signature_request)) } else { @@ -480,7 +488,7 @@ impl OlmMachine { let request = identity.as_upload_request().await; // TODO remove this expect. let signature_request = - identity.sign_account(&self.account).await.expect("Can't sign device keys"); + identity.sign_account(&self.inner.account).await.expect("Can't sign device keys"); Ok((request, signature_request)) } } @@ -489,7 +497,7 @@ impl OlmMachine { #[cfg(any(test, feature = "testing"))] #[allow(dead_code)] pub(crate) fn account(&self) -> &ReadOnlyAccount { - &self.account + &self.inner.account } /// Receive a successful keys upload response. @@ -502,7 +510,7 @@ impl OlmMachine { &self, response: &upload_keys::v3::Response, ) -> OlmResult<()> { - self.account.receive_keys_upload_response(response).await + self.inner.account.receive_keys_upload_response(response).await } /// Get the a key claiming request for the user/device pairs that we are @@ -536,7 +544,7 @@ impl OlmMachine { &self, users: impl Iterator, ) -> StoreResult> { - self.session_manager.get_missing_sessions(users).await + self.inner.session_manager.get_missing_sessions(users).await } /// Receive a successful key claim response and create new Olm sessions with @@ -546,7 +554,7 @@ impl OlmMachine { /// /// * `response` - The response containing the claimed one-time keys. async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { - self.session_manager.receive_keys_claim_response(response).await + self.inner.session_manager.receive_keys_claim_response(response).await } /// Receive a successful keys query response. @@ -563,7 +571,7 @@ impl OlmMachine { request_id: &TransactionId, response: &KeysQueryResponse, ) -> OlmResult<(DeviceChanges, IdentityChanges)> { - self.identity_manager.receive_keys_query_response(request_id, response).await + self.inner.identity_manager.receive_keys_query_response(request_id, response).await } /// Get a request to upload E2EE keys to the server. @@ -576,7 +584,8 @@ impl OlmMachine { /// [`receive_keys_upload_response`]: #method.receive_keys_upload_response /// [`OlmMachine`]: struct.OlmMachine.html async fn keys_for_upload(&self) -> Option { - let (device_keys, one_time_keys, fallback_keys) = self.account.keys_for_upload().await; + let (device_keys, one_time_keys, fallback_keys) = + self.inner.account.keys_for_upload().await; if device_keys.is_none() && one_time_keys.is_empty() && fallback_keys.is_empty() { None @@ -601,7 +610,7 @@ impl OlmMachine { &self, event: &EncryptedToDeviceEvent, ) -> OlmResult { - let mut decrypted = self.account.decrypt_to_device_event(event).await?; + let mut decrypted = self.inner.account.decrypt_to_device_event(event).await?; // Handle the decrypted event, e.g. fetch out Megolm sessions out of // the event. self.handle_decrypted_to_device_event(&mut decrypted).await?; @@ -635,7 +644,7 @@ impl OlmMachine { Ok(session) => { tracing::Span::current().record("session_id", session.session_id()); - if self.store.compare_group_session(&session).await? == SessionOrdering::Better { + if self.store().compare_group_session(&session).await? == SessionOrdering::Better { info!("Received a new megolm room key"); Ok(Some(session)) @@ -699,11 +708,12 @@ impl OlmMachine { room_id: &RoomId, ) -> OlmResult<()> { let (_, session) = self + .inner .group_session_manager .create_outbound_group_session(room_id, EncryptionSettings::default()) .await?; - self.store.save_inbound_group_sessions(&[session]).await?; + self.store().save_inbound_group_sessions(&[session]).await?; Ok(()) } @@ -715,6 +725,7 @@ impl OlmMachine { room_id: &RoomId, ) -> OlmResult { let (_, session) = self + .inner .group_session_manager .create_outbound_group_session(room_id, EncryptionSettings::default()) .await?; @@ -773,7 +784,7 @@ impl OlmMachine { content: Value, event_type: &str, ) -> MegolmResult> { - self.group_session_manager.encrypt(room_id, content, event_type).await + self.inner.group_session_manager.encrypt(room_id, content, event_type).await } /// Invalidate the currently active outbound group session for the given @@ -782,7 +793,7 @@ impl OlmMachine { /// Returns true if a session was invalidated, false if there was no session /// to invalidate. pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult { - self.group_session_manager.invalidate_group_session(room_id).await + self.inner.group_session_manager.invalidate_group_session(room_id).await } /// Get to-device requests to share a room key with users in a room. @@ -802,7 +813,7 @@ impl OlmMachine { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult>> { - self.group_session_manager.share_room_key(room_id, users, encryption_settings).await + self.inner.group_session_manager.share_room_key(room_id, users, encryption_settings).await } /// Receive an unencrypted verification event. @@ -817,7 +828,7 @@ impl OlmMachine { &self, event: &AnyMessageLikeEvent, ) -> StoreResult<()> { - self.verification_machine.receive_any_event(event).await + self.inner.verification_machine.receive_any_event(event).await } /// Receive a verification event. @@ -825,7 +836,7 @@ impl OlmMachine { /// in rooms to the `OlmMachine`. The event should be in the decrypted form. /// in rooms to the `OlmMachine`. pub async fn receive_verification_event(&self, event: &AnyMessageLikeEvent) -> StoreResult<()> { - self.verification_machine.receive_any_event(event).await + self.inner.verification_machine.receive_any_event(event).await } /// Receive and properly handle a decrypted to-device event. @@ -853,6 +864,7 @@ impl OlmMachine { } AnyDecryptedOlmEvent::ForwardedRoomKey(e) => { let session = self + .inner .key_request_machine .receive_forwarded_room_key(decrypted.result.sender_key, e) .await?; @@ -860,6 +872,7 @@ impl OlmMachine { } AnyDecryptedOlmEvent::SecretSend(e) => { let name = self + .inner .key_request_machine .receive_secret_event(decrypted.result.sender_key, e) .await?; @@ -885,23 +898,23 @@ impl OlmMachine { } async fn handle_verification_event(&self, event: &ToDeviceEvents) { - if let Err(e) = self.verification_machine.receive_any_event(event).await { + if let Err(e) = self.inner.verification_machine.receive_any_event(event).await { error!("Error handling a verification event: {e:?}"); } } /// Mark an outgoing to-device requests as sent. async fn mark_to_device_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> { - self.verification_machine.mark_request_as_sent(request_id); - self.key_request_machine.mark_outgoing_request_as_sent(request_id).await?; - self.group_session_manager.mark_request_as_sent(request_id).await?; - self.session_manager.mark_outgoing_request_as_sent(request_id); + self.inner.verification_machine.mark_request_as_sent(request_id); + self.inner.key_request_machine.mark_outgoing_request_as_sent(request_id).await?; + self.inner.group_session_manager.mark_request_as_sent(request_id).await?; + self.inner.session_manager.mark_outgoing_request_as_sent(request_id); Ok(()) } /// Get a verification object for the given user id with the given flow id. pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { - self.verification_machine.get_verification(user_id, flow_id) + self.inner.verification_machine.get_verification(user_id, flow_id) } /// Get a verification request object with the given flow id. @@ -910,12 +923,12 @@ impl OlmMachine { user_id: &UserId, flow_id: impl AsRef, ) -> Option { - self.verification_machine.get_request(user_id, flow_id) + self.inner.verification_machine.get_request(user_id, flow_id) } /// Get all the verification requests of a given user. pub fn get_verification_requests(&self, user_id: &UserId) -> Vec { - self.verification_machine.get_requests(user_id) + self.inner.verification_machine.get_requests(user_id) } async fn update_key_counts( @@ -923,15 +936,15 @@ impl OlmMachine { one_time_key_count: &BTreeMap, unused_fallback_keys: Option<&[DeviceKeyAlgorithm]>, ) { - self.account.update_key_counts(one_time_key_count, unused_fallback_keys).await; + self.inner.account.update_key_counts(one_time_key_count, unused_fallback_keys).await; } async fn handle_to_device_event(&self, changes: &mut Changes, event: &ToDeviceEvents) { use crate::types::events::ToDeviceEvents::*; match event { - RoomKeyRequest(e) => self.key_request_machine.receive_incoming_key_request(e), - SecretRequest(e) => self.key_request_machine.receive_incoming_secret_request(e), + RoomKeyRequest(e) => self.inner.key_request_machine.receive_incoming_key_request(e), + SecretRequest(e) => self.inner.key_request_machine.receive_incoming_secret_request(e), RoomKeyWithheld(e) => self.add_withheld_info(changes, e).await, KeyVerificationAccept(..) | KeyVerificationCancel(..) @@ -999,8 +1012,11 @@ impl OlmMachine { Ok(e) => e, Err(err) => { if let OlmError::SessionWedged(sender, curve_key) = err { - if let Err(e) = - self.session_manager.mark_device_as_wedged(&sender, curve_key).await + if let Err(e) = self + .inner + .session_manager + .mark_device_as_wedged(&sender, curve_key) + .await { error!( error = ?e, @@ -1017,7 +1033,7 @@ impl OlmMachine { // one as well. match decrypted.session { SessionType::New(s) => { - changes.account = Some(self.account.inner.clone()); + changes.account = Some((*self.inner.account).clone()); changes.sessions.push(s); } SessionType::Existing(s) => { @@ -1081,16 +1097,17 @@ impl OlmMachine { unused_fallback_keys: Option<&[DeviceKeyAlgorithm]>, ) -> OlmResult>> { // Remove verification objects that have expired or are done. - let mut events = self.verification_machine.garbage_collect(); + let mut events = self.inner.verification_machine.garbage_collect(); // Always save the account, a new session might get created which also // touches the account. let mut changes = - Changes { account: Some(self.account.inner.clone()), ..Default::default() }; + Changes { account: Some((*self.inner.account).clone()), ..Default::default() }; self.update_key_counts(one_time_keys_counts, unused_fallback_keys).await; if let Err(e) = self + .inner .identity_manager .receive_device_changes(changed_devices.changed.iter().map(|u| u.as_ref())) .await @@ -1103,11 +1120,12 @@ impl OlmMachine { events.push(raw_event); } - let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?; + let changed_sessions = + self.inner.key_request_machine.collect_incoming_key_requests().await?; changes.sessions.extend(changed_sessions); - self.store.save_changes(changes).await?; + self.store().save_changes(changes).await?; Ok(events) } @@ -1134,7 +1152,7 @@ impl OlmMachine { room_id: &RoomId, ) -> MegolmResult<(Option, OutgoingRequest)> { let event = event.deserialize()?; - self.key_request_machine.request_key(room_id, &event).await + self.inner.key_request_machine.request_key(room_id, &event).await } async fn get_verification_state( @@ -1243,7 +1261,7 @@ impl OlmMachine { content: &SupportedEventEncryptionSchemes<'_>, ) -> MegolmResult { if let Some(session) = - self.store.get_inbound_group_session(room_id, content.session_id()).await? + self.store().get_inbound_group_session(room_id, content.session_id()).await? { tracing::Span::current().record("sender_key", session.sender_key().to_base64()); @@ -1262,6 +1280,7 @@ impl OlmMachine { error { let withheld_code = self + .inner .store .get_withheld_info(room_id, content.session_id()) .await? @@ -1280,6 +1299,7 @@ impl OlmMachine { } } else { let withheld_code = self + .inner .store .get_withheld_info(room_id, content.session_id()) .await? @@ -1329,7 +1349,10 @@ impl OlmMachine { // Maybe for some code there is no point MegolmError::MissingRoomKey(_) | MegolmError::Decryption(DecryptionError::UnknownMessageIndex(_, _)) => { - self.key_request_machine.create_outgoing_key_request(room_id, &event).await?; + self.inner + .key_request_machine + .create_outgoing_key_request(room_id, &event) + .await?; } _ => {} } @@ -1362,12 +1385,12 @@ impl OlmMachine { &self, users: impl IntoIterator, ) -> StoreResult<()> { - self.identity_manager.update_tracked_users(users).await + self.inner.identity_manager.update_tracked_users(users).await } async fn wait_if_user_pending(&self, user_id: &UserId, timeout: Option) { if let Some(timeout) = timeout { - self.store.wait_if_user_key_query_pending(timeout, user_id).await; + self.store().wait_if_user_key_query_pending(timeout, user_id).await; } } @@ -1408,7 +1431,7 @@ impl OlmMachine { timeout: Option, ) -> StoreResult> { self.wait_if_user_pending(user_id, timeout).await; - self.store.get_device(user_id, device_id).await + self.store().get_device(user_id, device_id).await } /// Get the cross signing user identity of a user. @@ -1430,7 +1453,7 @@ impl OlmMachine { timeout: Option, ) -> StoreResult> { self.wait_if_user_pending(user_id, timeout).await; - self.store.get_identity(user_id).await + self.store().get_identity(user_id).await } /// Get a map holding all the devices of an user. @@ -1466,7 +1489,7 @@ impl OlmMachine { timeout: Option, ) -> StoreResult { self.wait_if_user_pending(user_id, timeout).await; - self.store.get_user_devices(user_id).await + self.store().get_user_devices(user_id).await } /// Import the given room keys into our store. @@ -1525,6 +1548,7 @@ impl OlmMachine { match InboundGroupSession::from_export(&key) { Ok(session) => { let old_session = self + .inner .store .get_inbound_group_session(session.room_id(), session.session_id()) .await?; @@ -1564,7 +1588,7 @@ impl OlmMachine { let changes = Changes { inbound_group_sessions: sessions, ..Default::default() }; - self.store.save_changes(changes).await?; + self.store().save_changes(changes).await?; info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys"); @@ -1606,7 +1630,7 @@ impl OlmMachine { let mut exported = Vec::new(); let sessions: Vec = self - .store + .store() .get_inbound_group_sessions() .await? .into_iter() @@ -1626,7 +1650,7 @@ impl OlmMachine { /// This can be used to check which private cross signing keys we have /// stored locally. pub async fn cross_signing_status(&self) -> CrossSigningStatus { - self.user_identity.lock().await.status().await + self.inner.user_identity.lock().await.status().await } /// Export all the private cross signing keys we have. @@ -1637,11 +1661,11 @@ impl OlmMachine { /// This method returns `None` if we don't have any private cross signing /// keys. pub async fn export_cross_signing_keys(&self) -> Option { - let master_key = self.store.export_secret(&SecretName::CrossSigningMasterKey).await; + let master_key = self.store().export_secret(&SecretName::CrossSigningMasterKey).await; let self_signing_key = - self.store.export_secret(&SecretName::CrossSigningSelfSigningKey).await; + self.store().export_secret(&SecretName::CrossSigningSelfSigningKey).await; let user_signing_key = - self.store.export_secret(&SecretName::CrossSigningUserSigningKey).await; + self.store().export_secret(&SecretName::CrossSigningUserSigningKey).await; if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() { None @@ -1658,14 +1682,14 @@ impl OlmMachine { &self, export: CrossSigningKeyExport, ) -> Result { - self.store.import_cross_signing_keys(export).await + self.store().import_cross_signing_keys(export).await } async fn sign_with_master_key( &self, message: &str, ) -> Result<(OwnedDeviceKeyId, Ed25519Signature), SignatureError> { - let identity = &*self.user_identity.lock().await; + let identity = &*self.inner.user_identity.lock().await; let key_id = identity.master_key_id().await.ok_or(SignatureError::MissingSigningKey)?; let signature = identity.sign(message).await?; @@ -1681,8 +1705,8 @@ impl OlmMachine { pub async fn sign(&self, message: &str) -> Signatures { let mut signatures = Signatures::new(); - let key_id = self.account.signing_key_id(); - let signature = self.account.sign(message).await; + let key_id = self.inner.account.signing_key_id(); + let signature = self.inner.account.sign(message).await; signatures.add_signature(self.user_id().to_owned(), key_id, signature); match self.sign_with_master_key(message).await { @@ -1703,7 +1727,7 @@ impl OlmMachine { /// the server. #[cfg(feature = "backups_v1")] pub fn backup_machine(&self) -> &BackupMachine { - &self.backup_machine + &self.inner.backup_machine } } @@ -1833,7 +1857,7 @@ pub(crate) mod tests { pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { let machine = OlmMachine::new(user_id(), bob_device_id()).await; - machine.account.inner.update_uploaded_key_count(0); + machine.account().update_uploaded_key_count(0); let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload"); let response = keys_upload_response(); machine.receive_keys_upload_response(&response).await.unwrap(); @@ -1860,8 +1884,8 @@ pub(crate) mod tests { let alice_device = ReadOnlyDevice::from_machine(&alice).await; let bob_device = ReadOnlyDevice::from_machine(&bob).await; - alice.store.save_devices(&[bob_device]).await.unwrap(); - bob.store.save_devices(&[alice_device]).await.unwrap(); + alice.store().save_devices(&[bob_device]).await.unwrap(); + bob.store().save_devices(&[alice_device]).await.unwrap(); (alice, bob, otk) } @@ -1890,19 +1914,19 @@ pub(crate) mod tests { let (alice, bob) = get_machine_pair_with_session().await; let bob_device = - alice.get_device(&bob.user_id, &bob.device_id, None).await.unwrap().unwrap(); + alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap(); let (session, content) = bob_device .encrypt("m.dummy", serde_json::to_value(ToDeviceDummyEventContent::new()).unwrap()) .await .unwrap(); - alice.store.save_sessions(&[session]).await.unwrap(); + alice.store().save_sessions(&[session]).await.unwrap(); let event = ToDeviceEvent::new(alice.user_id().to_owned(), content.deserialize_as().unwrap()); let decrypted = bob.decrypt_to_device_event(&event).await.unwrap(); - bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap(); + bob.store().save_sessions(&[decrypted.session.session()]).await.unwrap(); (alice, bob) } @@ -1925,28 +1949,28 @@ pub(crate) mod tests { async fn generate_one_time_keys() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - assert!(machine.account.generate_one_time_keys().await.is_some()); + assert!(machine.account().generate_one_time_keys().await.is_some()); let mut response = keys_upload_response(); machine.receive_keys_upload_response(&response).await.unwrap(); - assert!(machine.account.generate_one_time_keys().await.is_some()); + assert!(machine.account().generate_one_time_keys().await.is_some()); response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50)); machine.receive_keys_upload_response(&response).await.unwrap(); - assert!(machine.account.generate_one_time_keys().await.is_none()); + assert!(machine.account().generate_one_time_keys().await.is_none()); } #[async_test] async fn test_device_key_signing() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - let device_keys = machine.account.device_keys().await; - let identity_keys = machine.account.identity_keys(); + let device_keys = machine.account().device_keys().await; + let identity_keys = machine.account().identity_keys(); let ed25519_key = identity_keys.ed25519; let ret = ed25519_key.verify_json( - &machine.user_id, + machine.user_id(), &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &device_keys, ); @@ -1959,11 +1983,12 @@ pub(crate) mod tests { let room_id = room_id!("!test:example.org"); machine.create_outbound_group_session_with_defaults(room_id).await.unwrap(); - assert!(machine.group_session_manager.get_outbound_group_session(room_id).is_some()); + assert!(machine.inner.group_session_manager.get_outbound_group_session(room_id).is_some()); machine.invalidate_group_session(room_id).await.unwrap(); assert!(machine + .inner .group_session_manager .get_outbound_group_session(room_id) .unwrap() @@ -1974,12 +1999,12 @@ pub(crate) mod tests { async fn test_invalid_signature() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - let device_keys = machine.account.device_keys().await; + let device_keys = machine.account().device_keys().await; let key = Ed25519PublicKey::from_slice(&[0u8; 32]).unwrap(); let ret = key.verify_json( - &machine.user_id, + machine.user_id(), &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &device_keys, ); @@ -1989,10 +2014,10 @@ pub(crate) mod tests { #[async_test] async fn one_time_key_signing() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - machine.account.inner.update_uploaded_key_count(49); + machine.account().update_uploaded_key_count(49); - let mut one_time_keys = machine.account.signed_one_time_keys().await; - let ed25519_key = machine.account.identity_keys().ed25519; + let mut one_time_keys = machine.account().signed_one_time_keys().await; + let ed25519_key = machine.account().identity_keys().ed25519; let one_time_key: SignedKey = one_time_keys .values_mut() @@ -2003,7 +2028,7 @@ pub(crate) mod tests { ed25519_key .verify_json( - &machine.user_id, + machine.user_id(), &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &one_time_key, ) @@ -2013,9 +2038,9 @@ pub(crate) mod tests { #[async_test] async fn test_keys_for_upload() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; - machine.account.inner.update_uploaded_key_count(0); + machine.account().update_uploaded_key_count(0); - let ed25519_key = machine.account.identity_keys().ed25519; + let ed25519_key = machine.account().identity_keys().ed25519; let mut request = machine.keys_for_upload().await.expect("Can't prepare initial key upload"); @@ -2029,7 +2054,7 @@ pub(crate) mod tests { .unwrap(); let ret = ed25519_key.verify_json( - &machine.user_id, + machine.user_id(), &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &one_time_key, ); @@ -2038,7 +2063,7 @@ pub(crate) mod tests { let device_keys: DeviceKeys = request.device_keys.unwrap().deserialize_as().unwrap(); let ret = ed25519_key.verify_json( - &machine.user_id, + machine.user_id(), &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()), &device_keys, ); @@ -2063,13 +2088,13 @@ pub(crate) mod tests { let alice_id = user_id!("@alice:example.org"); let alice_device_id: &DeviceId = device_id!("JLAFKJWSCS"); - let alice_devices = machine.store.get_user_devices(alice_id).await.unwrap(); + let alice_devices = machine.store().get_user_devices(alice_id).await.unwrap(); assert!(alice_devices.devices().peekable().peek().is_none()); let req_id = TransactionId::new(); machine.receive_keys_query_response(&req_id, &response).await.unwrap(); - let device = machine.store.get_device(alice_id, alice_device_id).await.unwrap().unwrap(); + let device = machine.store().get_device(alice_id, alice_device_id).await.unwrap().unwrap(); assert_eq!(device.user_id(), alice_id); assert_eq!(device.device_id(), alice_device_id); } @@ -2119,8 +2144,8 @@ pub(crate) mod tests { .await; let session = alice_machine - .store - .get_sessions(&bob_machine.account.identity_keys().curve25519.to_base64()) + .store() + .get_sessions(&bob_machine.account().identity_keys().curve25519.to_base64()) .await .unwrap() .unwrap(); @@ -2168,7 +2193,7 @@ pub(crate) mod tests { // with the same timestamps, so we manually masage them to be 10s apart. let session_id = { let sessions = alice_machine - .store + .store() .get_sessions(&bob_machine.identity_keys().curve25519.to_base64()) .await .unwrap() @@ -2207,7 +2232,7 @@ pub(crate) mod tests { let (alice, bob) = get_machine_pair_with_session().await; let bob_device = - alice.get_device(&bob.user_id, &bob.device_id, None).await.unwrap().unwrap(); + alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap(); let event = ToDeviceEvent::new( alice.user_id().to_owned(), @@ -2254,7 +2279,7 @@ pub(crate) mod tests { let event = json_convert(&event).unwrap(); let alice_session = - alice.group_session_manager.get_outbound_group_session(room_id).unwrap(); + alice.inner.group_session_manager.get_outbound_group_session(room_id).unwrap(); let decrypted = bob .receive_sync_changes(vec![event], &Default::default(), &Default::default(), None) @@ -2271,7 +2296,7 @@ pub(crate) mod tests { } let session = - bob.store.get_inbound_group_session(room_id, alice_session.session_id()).await; + bob.store().get_inbound_group_session(room_id, alice_session.session_id()).await; assert!(session.unwrap().is_some()); } @@ -2295,7 +2320,7 @@ pub(crate) mod tests { let group_session = bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session.unwrap(); - bob.store.save_inbound_group_sessions(&[group_session.clone()]).await.unwrap(); + bob.store().save_inbound_group_sessions(&[group_session.clone()]).await.unwrap(); // when we decrypt the room key, the // inbound_group_session_streamroom_keys_received_stream should tell us @@ -2441,7 +2466,7 @@ pub(crate) mod tests { let export = group_session.as_ref().unwrap().clone().export().await; - bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap(); + bob.store().save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap(); let plaintext = "It is a secret to everybody"; @@ -2525,7 +2550,7 @@ pub(crate) mod tests { // Simulate an imported session, to change verification state let imported = InboundGroupSession::from_export(&export).unwrap(); - bob.store.save_inbound_group_sessions(&[imported]).await.unwrap(); + bob.store().save_inbound_group_sessions(&[imported]).await.unwrap(); let encryption_info = bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); @@ -2751,7 +2776,7 @@ pub(crate) mod tests { assert_eq!(device_change.new.len(), 2); assert_eq!(identity_change.new.len(), 1); // - let devices = bob.store.get_user_devices(other_user_id).await.unwrap(); + let devices = bob.store().get_user_devices(other_user_id).await.unwrap(); assert_eq!(devices.devices().count(), 2); let fake_room_id = room_id!("!roomid:example.com"); @@ -2760,7 +2785,7 @@ pub(crate) mod tests { // We will use the export to create various inbounds with other claimed // ownership let id_keys = bob.identity_keys(); - let fake_device_id = bob.device_id.clone(); + let fake_device_id = bob.device_id().into(); let olm = OutboundGroupSession::new( fake_device_id, Arc::new(id_keys), @@ -2816,7 +2841,7 @@ pub(crate) mod tests { let bob_other_device = device_id!("OTHERBOB"); let bob_other_machine = OlmMachine::new(bob_id, bob_other_device).await; let bob_other_device = ReadOnlyDevice::from_machine(&bob_other_machine).await; - bob.store.save_devices(&[bob_other_device]).await.unwrap(); + bob.store().save_devices(&[bob_other_device]).await.unwrap(); bob.get_device(bob_id, device_id!("OTHERBOB"), None) .await .unwrap() @@ -2855,7 +2880,7 @@ pub(crate) mod tests { let group_session = bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session; - bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap(); + bob.store().save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap(); let room_event = json_convert(&room_event).unwrap(); @@ -2865,7 +2890,7 @@ pub(crate) mod tests { if let vodozemac::megolm::DecryptionError::UnknownMessageIndex(_, _) = vodo_error { // check that key has been requested let outgoing_to_devices = - bob.key_request_machine.outgoing_to_device_requests().await.unwrap(); + bob.inner.key_request_machine.outgoing_to_device_requests().await.unwrap(); assert_eq!(1, outgoing_to_devices.len()); } else { panic!("Should be UnknownMessageIndex error ") @@ -2903,6 +2928,7 @@ pub(crate) mod tests { alice.handle_verification_event(&event).await; let (event, request_id) = alice + .inner .verification_machine .outgoing_messages() .first() @@ -2912,6 +2938,7 @@ pub(crate) mod tests { bob.handle_verification_event(&event).await; let (event, request_id) = bob + .inner .verification_machine .outgoing_messages() .first() @@ -3026,11 +3053,11 @@ pub(crate) mod tests { alice.handle_verification_event(&event).await; // Alice sends a key - let msgs = alice.verification_machine.outgoing_messages(); + let msgs = alice.inner.verification_machine.outgoing_messages(); assert!(msgs.len() == 1); let msg = msgs.first().unwrap(); let event = outgoing_request_to_event(alice.user_id(), msg); - alice.verification_machine.mark_request_as_sent(&msg.request_id); + alice.inner.verification_machine.mark_request_as_sent(&msg.request_id); // ---------------------------------------------------------------------------- // On Bob's device: @@ -3039,11 +3066,11 @@ pub(crate) mod tests { bob.handle_verification_event(&event).await; // Now bob sends a key - let msgs = bob.verification_machine.outgoing_messages(); + let msgs = bob.inner.verification_machine.outgoing_messages(); assert!(msgs.len() == 1); let msg = msgs.first().unwrap(); let event = outgoing_request_to_event(bob.user_id(), msg); - bob.verification_machine.mark_request_as_sent(&msg.request_id); + bob.inner.verification_machine.mark_request_as_sent(&msg.request_id); // ---------------------------------------------------------------------------- // On Alice's device: @@ -3090,7 +3117,7 @@ pub(crate) mod tests { bob.handle_verification_event(&event_mac).await; // Bob verifies that the MAC is valid and also sends a "done" message. - let msgs = bob.verification_machine.outgoing_messages(); + let msgs = bob.inner.verification_machine.outgoing_messages(); eprintln!("{msgs:?}"); assert!(msgs.len() == 1); let event = msgs.first().map(|r| outgoing_request_to_event(bob.user_id(), r)).unwrap(); @@ -3174,7 +3201,7 @@ pub(crate) mod tests { bob.receive_sync_changes(vec![event], &changed_devices, &key_counts, None).await.unwrap(); - let session = bob.store.get_inbound_group_session(room_id, &session_id).await; + let session = bob.store().get_inbound_group_session(room_id, &session_id).await; assert!(session.unwrap().is_none()); } @@ -3184,10 +3211,10 @@ pub(crate) mod tests { let room_id = room_id!("!test:localhost"); let (alice, _) = get_machine_pair_with_setup_sessions().await; let device = ReadOnlyDevice::from_machine(&alice).await; - alice.store.save_devices(&[device]).await.unwrap(); + alice.store().save_devices(&[device]).await.unwrap(); let (outbound, mut inbound) = - alice.account.create_group_session_pair(room_id, Default::default()).await.unwrap(); + alice.account().create_group_session_pair(room_id, Default::default()).await.unwrap(); let fake_key = Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE") .unwrap() @@ -3197,7 +3224,7 @@ pub(crate) mod tests { let content = json!({}); let content = outbound.encrypt(content, "m.dummy").await; - alice.store.save_inbound_group_sessions(&[inbound]).await.unwrap(); + alice.store().save_inbound_group_sessions(&[inbound]).await.unwrap(); let event = json!({ "sender": alice.user_id(), diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index b170b14e9..3d5834818 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -926,7 +926,8 @@ mod tests { let requests = machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap(); - let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap(); + let outbound = + machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap(); assert!(!outbound.pending_requests().is_empty()); assert!(!outbound.shared()); @@ -1066,7 +1067,8 @@ mod tests { .filter(|r| r.event_type == "m.room.encrypted".into()) .map(|r| r.message_count()) .sum(); - let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap(); + let outbound = + machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap(); assert_eq!(event_count, 1); assert!(!outbound.pending_requests().is_empty()); @@ -1079,9 +1081,11 @@ mod tests { let keys_claim = keys_claim_response(); let users = keys_claim.one_time_keys.keys().map(Deref::deref); - let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap(); + let outbound = + machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap(); let CollectRecipientsResult { should_rotate, .. } = machine + .inner .group_session_manager .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound) .await @@ -1095,6 +1099,7 @@ mod tests { }; let CollectRecipientsResult { should_rotate, .. } = machine + .inner .group_session_manager .collect_session_recipients(users.clone(), &settings, &outbound) .await @@ -1108,6 +1113,7 @@ mod tests { }; let CollectRecipientsResult { should_rotate, .. } = machine + .inner .group_session_manager .collect_session_recipients(users, &settings, &outbound) .await @@ -1127,6 +1133,7 @@ mod tests { let machine = machine_with_user(user_id, device_id).await; let (outbound, _) = machine + .inner .group_session_manager .get_or_create_outbound_session(room_id, EncryptionSettings::default()) .await @@ -1137,6 +1144,7 @@ mod tests { let users = [user_id].into_iter(); let CollectRecipientsResult { devices: recipients, .. } = machine + .inner .group_session_manager .collect_session_recipients(users, &settings, &outbound) .await @@ -1154,6 +1162,7 @@ mod tests { let users = [user_id].into_iter(); let CollectRecipientsResult { devices: recipients, .. } = machine + .inner .group_session_manager .collect_session_recipients(users, &settings, &outbound) .await @@ -1168,6 +1177,7 @@ mod tests { let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } = machine + .inner .group_session_manager .collect_session_recipients(users, &settings, &outbound) .await