crypto: Reduce stack size of OlmMachine

This commit is contained in:
Jonas Platte
2023-05-04 19:16:51 +02:00
committed by Jonas Platte
parent e5375a475e
commit 33c84b9ffd
2 changed files with 163 additions and 126 deletions

View File

@@ -94,6 +94,10 @@ use crate::{
/// Matrix end to end encryption.
#[derive(Clone)]
pub struct OlmMachine {
pub(crate) inner: Arc<OlmMachineInner>,
}
pub struct OlmMachineInner {
/// The unique user id that owns this account.
user_id: OwnedUserId,
/// The unique device ID of the device that holds this account.
@@ -131,8 +135,8 @@ pub struct OlmMachine {
impl std::fmt::Debug for OlmMachine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OlmMachine")
.field("user_id", &self.user_id)
.field("device_id", &self.device_id)
.field("user_id", &self.user_id())
.field("device_id", &self.device_id())
.finish()
}
}
@@ -195,7 +199,7 @@ impl OlmMachine {
#[cfg(feature = "backups_v1")]
let backup_machine = BackupMachine::new(account.clone(), store.clone(), None);
OlmMachine {
let inner = Arc::new(OlmMachineInner {
user_id,
device_id,
account,
@@ -208,7 +212,9 @@ impl OlmMachine {
identity_manager,
#[cfg(feature = "backups_v1")]
backup_machine,
}
});
Self { inner }
}
/// Create a new OlmMachine with the given [`CryptoStore`].
@@ -299,27 +305,27 @@ impl OlmMachine {
/// Get the crypto store associated with this `OlmMachine` instance.
pub fn store(&self) -> &Store {
&self.store
&self.inner.store
}
/// The unique user id that owns this `OlmMachine` instance.
pub fn user_id(&self) -> &UserId {
&self.user_id
&self.inner.user_id
}
/// The unique device ID that identifies this `OlmMachine`.
pub fn device_id(&self) -> &DeviceId {
&self.device_id
&self.inner.device_id
}
/// Get the public parts of our Olm identity keys.
pub fn identity_keys(&self) -> IdentityKeys {
self.account.identity_keys()
self.inner.account.identity_keys()
}
/// Get the display name of our own device
pub async fn display_name(&self) -> StoreResult<Option<String>> {
self.store.device_display_name().await
self.store().device_display_name().await
}
/// Get the list of "tracked users".
@@ -327,7 +333,7 @@ impl OlmMachine {
/// See [`update_tracked_users`](#method.update_tracked_users) for more
/// information.
pub async fn tracked_users(&self) -> StoreResult<HashSet<OwnedUserId>> {
self.store.tracked_users().await
self.store().tracked_users().await
}
/// Enable or disable room key forwarding.
@@ -337,12 +343,12 @@ impl OlmMachine {
/// events.
#[cfg(feature = "automatic-room-key-forwarding")]
pub fn toggle_room_key_forwarding(&self, enable: bool) {
self.key_request_machine.toggle_room_key_forwarding(enable)
self.inner.key_request_machine.toggle_room_key_forwarding(enable)
}
/// Is room key forwarding enabled?
pub fn is_room_key_forwarding_enabled(&self) -> bool {
self.key_request_machine.is_room_key_forwarding_enabled()
self.inner.key_request_machine.is_room_key_forwarding_enabled()
}
/// Get the outgoing requests that need to be sent out.
@@ -359,11 +365,12 @@ impl OlmMachine {
request_id: TransactionId::new(),
request: Arc::new(r.into()),
}) {
self.account.save().await?;
self.inner.account.save().await?;
requests.push(r);
}
for request in self
.inner
.identity_manager
.users_for_key_query()
.await?
@@ -373,8 +380,8 @@ impl OlmMachine {
requests.push(request);
}
requests.append(&mut self.verification_machine.outgoing_messages());
requests.append(&mut self.key_request_machine.outgoing_to_device_requests().await?);
requests.append(&mut self.inner.verification_machine.outgoing_messages());
requests.append(&mut self.inner.key_request_machine.outgoing_to_device_requests().await?);
Ok(requests)
}
@@ -410,14 +417,14 @@ impl OlmMachine {
self.receive_cross_signing_upload_response().await?;
}
IncomingResponse::SignatureUpload(_) => {
self.verification_machine.mark_request_as_sent(request_id);
self.inner.verification_machine.mark_request_as_sent(request_id);
}
IncomingResponse::RoomMessage(_) => {
self.verification_machine.mark_request_as_sent(request_id);
self.inner.verification_machine.mark_request_as_sent(request_id);
}
IncomingResponse::KeysBackup(_) => {
#[cfg(feature = "backups_v1")]
self.backup_machine.mark_request_as_sent(request_id).await?;
self.inner.backup_machine.mark_request_as_sent(request_id).await?;
}
};
@@ -426,12 +433,12 @@ impl OlmMachine {
/// Mark the cross signing identity as shared.
async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> {
let identity = self.user_identity.lock().await;
let identity = self.inner.user_identity.lock().await;
identity.mark_as_shared();
let changes = Changes { private_identity: Some(identity.clone()), ..Default::default() };
self.store.save_changes(changes).await
self.store().save_changes(changes).await
}
/// Create a new cross signing identity and get the upload request to push
@@ -454,11 +461,12 @@ impl OlmMachine {
&self,
reset: bool,
) -> StoreResult<(UploadSigningKeysRequest, UploadSignaturesRequest)> {
let mut identity = self.user_identity.lock().await;
let mut identity = self.inner.user_identity.lock().await;
if identity.is_empty().await || reset {
info!("Creating new cross signing identity");
let (id, request, signature_request) = self.account.bootstrap_cross_signing().await;
let (id, request, signature_request) =
self.inner.account.bootstrap_cross_signing().await;
*identity = id;
@@ -472,7 +480,7 @@ impl OlmMachine {
..Default::default()
};
self.store.save_changes(changes).await?;
self.store().save_changes(changes).await?;
Ok((request, signature_request))
} else {
@@ -480,7 +488,7 @@ impl OlmMachine {
let request = identity.as_upload_request().await;
// TODO remove this expect.
let signature_request =
identity.sign_account(&self.account).await.expect("Can't sign device keys");
identity.sign_account(&self.inner.account).await.expect("Can't sign device keys");
Ok((request, signature_request))
}
}
@@ -489,7 +497,7 @@ impl OlmMachine {
#[cfg(any(test, feature = "testing"))]
#[allow(dead_code)]
pub(crate) fn account(&self) -> &ReadOnlyAccount {
&self.account
&self.inner.account
}
/// Receive a successful keys upload response.
@@ -502,7 +510,7 @@ impl OlmMachine {
&self,
response: &upload_keys::v3::Response,
) -> OlmResult<()> {
self.account.receive_keys_upload_response(response).await
self.inner.account.receive_keys_upload_response(response).await
}
/// Get the a key claiming request for the user/device pairs that we are
@@ -536,7 +544,7 @@ impl OlmMachine {
&self,
users: impl Iterator<Item = &UserId>,
) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
self.session_manager.get_missing_sessions(users).await
self.inner.session_manager.get_missing_sessions(users).await
}
/// Receive a successful key claim response and create new Olm sessions with
@@ -546,7 +554,7 @@ impl OlmMachine {
///
/// * `response` - The response containing the claimed one-time keys.
async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> {
self.session_manager.receive_keys_claim_response(response).await
self.inner.session_manager.receive_keys_claim_response(response).await
}
/// Receive a successful keys query response.
@@ -563,7 +571,7 @@ impl OlmMachine {
request_id: &TransactionId,
response: &KeysQueryResponse,
) -> OlmResult<(DeviceChanges, IdentityChanges)> {
self.identity_manager.receive_keys_query_response(request_id, response).await
self.inner.identity_manager.receive_keys_query_response(request_id, response).await
}
/// Get a request to upload E2EE keys to the server.
@@ -576,7 +584,8 @@ impl OlmMachine {
/// [`receive_keys_upload_response`]: #method.receive_keys_upload_response
/// [`OlmMachine`]: struct.OlmMachine.html
async fn keys_for_upload(&self) -> Option<upload_keys::v3::Request> {
let (device_keys, one_time_keys, fallback_keys) = self.account.keys_for_upload().await;
let (device_keys, one_time_keys, fallback_keys) =
self.inner.account.keys_for_upload().await;
if device_keys.is_none() && one_time_keys.is_empty() && fallback_keys.is_empty() {
None
@@ -601,7 +610,7 @@ impl OlmMachine {
&self,
event: &EncryptedToDeviceEvent,
) -> OlmResult<OlmDecryptionInfo> {
let mut decrypted = self.account.decrypt_to_device_event(event).await?;
let mut decrypted = self.inner.account.decrypt_to_device_event(event).await?;
// Handle the decrypted event, e.g. fetch out Megolm sessions out of
// the event.
self.handle_decrypted_to_device_event(&mut decrypted).await?;
@@ -635,7 +644,7 @@ impl OlmMachine {
Ok(session) => {
tracing::Span::current().record("session_id", session.session_id());
if self.store.compare_group_session(&session).await? == SessionOrdering::Better {
if self.store().compare_group_session(&session).await? == SessionOrdering::Better {
info!("Received a new megolm room key");
Ok(Some(session))
@@ -699,11 +708,12 @@ impl OlmMachine {
room_id: &RoomId,
) -> OlmResult<()> {
let (_, session) = self
.inner
.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default())
.await?;
self.store.save_inbound_group_sessions(&[session]).await?;
self.store().save_inbound_group_sessions(&[session]).await?;
Ok(())
}
@@ -715,6 +725,7 @@ impl OlmMachine {
room_id: &RoomId,
) -> OlmResult<InboundGroupSession> {
let (_, session) = self
.inner
.group_session_manager
.create_outbound_group_session(room_id, EncryptionSettings::default())
.await?;
@@ -773,7 +784,7 @@ impl OlmMachine {
content: Value,
event_type: &str,
) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
self.group_session_manager.encrypt(room_id, content, event_type).await
self.inner.group_session_manager.encrypt(room_id, content, event_type).await
}
/// Invalidate the currently active outbound group session for the given
@@ -782,7 +793,7 @@ impl OlmMachine {
/// Returns true if a session was invalidated, false if there was no session
/// to invalidate.
pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
self.group_session_manager.invalidate_group_session(room_id).await
self.inner.group_session_manager.invalidate_group_session(room_id).await
}
/// Get to-device requests to share a room key with users in a room.
@@ -802,7 +813,7 @@ impl OlmMachine {
users: impl Iterator<Item = &UserId>,
encryption_settings: impl Into<EncryptionSettings>,
) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
self.group_session_manager.share_room_key(room_id, users, encryption_settings).await
self.inner.group_session_manager.share_room_key(room_id, users, encryption_settings).await
}
/// Receive an unencrypted verification event.
@@ -817,7 +828,7 @@ impl OlmMachine {
&self,
event: &AnyMessageLikeEvent,
) -> StoreResult<()> {
self.verification_machine.receive_any_event(event).await
self.inner.verification_machine.receive_any_event(event).await
}
/// Receive a verification event.
@@ -825,7 +836,7 @@ impl OlmMachine {
/// in rooms to the `OlmMachine`. The event should be in the decrypted form.
/// in rooms to the `OlmMachine`.
pub async fn receive_verification_event(&self, event: &AnyMessageLikeEvent) -> StoreResult<()> {
self.verification_machine.receive_any_event(event).await
self.inner.verification_machine.receive_any_event(event).await
}
/// Receive and properly handle a decrypted to-device event.
@@ -853,6 +864,7 @@ impl OlmMachine {
}
AnyDecryptedOlmEvent::ForwardedRoomKey(e) => {
let session = self
.inner
.key_request_machine
.receive_forwarded_room_key(decrypted.result.sender_key, e)
.await?;
@@ -860,6 +872,7 @@ impl OlmMachine {
}
AnyDecryptedOlmEvent::SecretSend(e) => {
let name = self
.inner
.key_request_machine
.receive_secret_event(decrypted.result.sender_key, e)
.await?;
@@ -885,23 +898,23 @@ impl OlmMachine {
}
async fn handle_verification_event(&self, event: &ToDeviceEvents) {
if let Err(e) = self.verification_machine.receive_any_event(event).await {
if let Err(e) = self.inner.verification_machine.receive_any_event(event).await {
error!("Error handling a verification event: {e:?}");
}
}
/// Mark an outgoing to-device requests as sent.
async fn mark_to_device_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
self.verification_machine.mark_request_as_sent(request_id);
self.key_request_machine.mark_outgoing_request_as_sent(request_id).await?;
self.group_session_manager.mark_request_as_sent(request_id).await?;
self.session_manager.mark_outgoing_request_as_sent(request_id);
self.inner.verification_machine.mark_request_as_sent(request_id);
self.inner.key_request_machine.mark_outgoing_request_as_sent(request_id).await?;
self.inner.group_session_manager.mark_request_as_sent(request_id).await?;
self.inner.session_manager.mark_outgoing_request_as_sent(request_id);
Ok(())
}
/// Get a verification object for the given user id with the given flow id.
pub fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option<Verification> {
self.verification_machine.get_verification(user_id, flow_id)
self.inner.verification_machine.get_verification(user_id, flow_id)
}
/// Get a verification request object with the given flow id.
@@ -910,12 +923,12 @@ impl OlmMachine {
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
self.verification_machine.get_request(user_id, flow_id)
self.inner.verification_machine.get_request(user_id, flow_id)
}
/// Get all the verification requests of a given user.
pub fn get_verification_requests(&self, user_id: &UserId) -> Vec<VerificationRequest> {
self.verification_machine.get_requests(user_id)
self.inner.verification_machine.get_requests(user_id)
}
async fn update_key_counts(
@@ -923,15 +936,15 @@ impl OlmMachine {
one_time_key_count: &BTreeMap<DeviceKeyAlgorithm, UInt>,
unused_fallback_keys: Option<&[DeviceKeyAlgorithm]>,
) {
self.account.update_key_counts(one_time_key_count, unused_fallback_keys).await;
self.inner.account.update_key_counts(one_time_key_count, unused_fallback_keys).await;
}
async fn handle_to_device_event(&self, changes: &mut Changes, event: &ToDeviceEvents) {
use crate::types::events::ToDeviceEvents::*;
match event {
RoomKeyRequest(e) => self.key_request_machine.receive_incoming_key_request(e),
SecretRequest(e) => self.key_request_machine.receive_incoming_secret_request(e),
RoomKeyRequest(e) => self.inner.key_request_machine.receive_incoming_key_request(e),
SecretRequest(e) => self.inner.key_request_machine.receive_incoming_secret_request(e),
RoomKeyWithheld(e) => self.add_withheld_info(changes, e).await,
KeyVerificationAccept(..)
| KeyVerificationCancel(..)
@@ -999,8 +1012,11 @@ impl OlmMachine {
Ok(e) => e,
Err(err) => {
if let OlmError::SessionWedged(sender, curve_key) = err {
if let Err(e) =
self.session_manager.mark_device_as_wedged(&sender, curve_key).await
if let Err(e) = self
.inner
.session_manager
.mark_device_as_wedged(&sender, curve_key)
.await
{
error!(
error = ?e,
@@ -1017,7 +1033,7 @@ impl OlmMachine {
// one as well.
match decrypted.session {
SessionType::New(s) => {
changes.account = Some(self.account.inner.clone());
changes.account = Some((*self.inner.account).clone());
changes.sessions.push(s);
}
SessionType::Existing(s) => {
@@ -1081,16 +1097,17 @@ impl OlmMachine {
unused_fallback_keys: Option<&[DeviceKeyAlgorithm]>,
) -> OlmResult<Vec<Raw<AnyToDeviceEvent>>> {
// Remove verification objects that have expired or are done.
let mut events = self.verification_machine.garbage_collect();
let mut events = self.inner.verification_machine.garbage_collect();
// Always save the account, a new session might get created which also
// touches the account.
let mut changes =
Changes { account: Some(self.account.inner.clone()), ..Default::default() };
Changes { account: Some((*self.inner.account).clone()), ..Default::default() };
self.update_key_counts(one_time_keys_counts, unused_fallback_keys).await;
if let Err(e) = self
.inner
.identity_manager
.receive_device_changes(changed_devices.changed.iter().map(|u| u.as_ref()))
.await
@@ -1103,11 +1120,12 @@ impl OlmMachine {
events.push(raw_event);
}
let changed_sessions = self.key_request_machine.collect_incoming_key_requests().await?;
let changed_sessions =
self.inner.key_request_machine.collect_incoming_key_requests().await?;
changes.sessions.extend(changed_sessions);
self.store.save_changes(changes).await?;
self.store().save_changes(changes).await?;
Ok(events)
}
@@ -1134,7 +1152,7 @@ impl OlmMachine {
room_id: &RoomId,
) -> MegolmResult<(Option<OutgoingRequest>, OutgoingRequest)> {
let event = event.deserialize()?;
self.key_request_machine.request_key(room_id, &event).await
self.inner.key_request_machine.request_key(room_id, &event).await
}
async fn get_verification_state(
@@ -1243,7 +1261,7 @@ impl OlmMachine {
content: &SupportedEventEncryptionSchemes<'_>,
) -> MegolmResult<TimelineEvent> {
if let Some(session) =
self.store.get_inbound_group_session(room_id, content.session_id()).await?
self.store().get_inbound_group_session(room_id, content.session_id()).await?
{
tracing::Span::current().record("sender_key", session.sender_key().to_base64());
@@ -1262,6 +1280,7 @@ impl OlmMachine {
error
{
let withheld_code = self
.inner
.store
.get_withheld_info(room_id, content.session_id())
.await?
@@ -1280,6 +1299,7 @@ impl OlmMachine {
}
} else {
let withheld_code = self
.inner
.store
.get_withheld_info(room_id, content.session_id())
.await?
@@ -1329,7 +1349,10 @@ impl OlmMachine {
// Maybe for some code there is no point
MegolmError::MissingRoomKey(_)
| MegolmError::Decryption(DecryptionError::UnknownMessageIndex(_, _)) => {
self.key_request_machine.create_outgoing_key_request(room_id, &event).await?;
self.inner
.key_request_machine
.create_outgoing_key_request(room_id, &event)
.await?;
}
_ => {}
}
@@ -1362,12 +1385,12 @@ impl OlmMachine {
&self,
users: impl IntoIterator<Item = &UserId>,
) -> StoreResult<()> {
self.identity_manager.update_tracked_users(users).await
self.inner.identity_manager.update_tracked_users(users).await
}
async fn wait_if_user_pending(&self, user_id: &UserId, timeout: Option<Duration>) {
if let Some(timeout) = timeout {
self.store.wait_if_user_key_query_pending(timeout, user_id).await;
self.store().wait_if_user_key_query_pending(timeout, user_id).await;
}
}
@@ -1408,7 +1431,7 @@ impl OlmMachine {
timeout: Option<Duration>,
) -> StoreResult<Option<Device>> {
self.wait_if_user_pending(user_id, timeout).await;
self.store.get_device(user_id, device_id).await
self.store().get_device(user_id, device_id).await
}
/// Get the cross signing user identity of a user.
@@ -1430,7 +1453,7 @@ impl OlmMachine {
timeout: Option<Duration>,
) -> StoreResult<Option<UserIdentities>> {
self.wait_if_user_pending(user_id, timeout).await;
self.store.get_identity(user_id).await
self.store().get_identity(user_id).await
}
/// Get a map holding all the devices of an user.
@@ -1466,7 +1489,7 @@ impl OlmMachine {
timeout: Option<Duration>,
) -> StoreResult<UserDevices> {
self.wait_if_user_pending(user_id, timeout).await;
self.store.get_user_devices(user_id).await
self.store().get_user_devices(user_id).await
}
/// Import the given room keys into our store.
@@ -1525,6 +1548,7 @@ impl OlmMachine {
match InboundGroupSession::from_export(&key) {
Ok(session) => {
let old_session = self
.inner
.store
.get_inbound_group_session(session.room_id(), session.session_id())
.await?;
@@ -1564,7 +1588,7 @@ impl OlmMachine {
let changes = Changes { inbound_group_sessions: sessions, ..Default::default() };
self.store.save_changes(changes).await?;
self.store().save_changes(changes).await?;
info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys");
@@ -1606,7 +1630,7 @@ impl OlmMachine {
let mut exported = Vec::new();
let sessions: Vec<InboundGroupSession> = self
.store
.store()
.get_inbound_group_sessions()
.await?
.into_iter()
@@ -1626,7 +1650,7 @@ impl OlmMachine {
/// This can be used to check which private cross signing keys we have
/// stored locally.
pub async fn cross_signing_status(&self) -> CrossSigningStatus {
self.user_identity.lock().await.status().await
self.inner.user_identity.lock().await.status().await
}
/// Export all the private cross signing keys we have.
@@ -1637,11 +1661,11 @@ impl OlmMachine {
/// This method returns `None` if we don't have any private cross signing
/// keys.
pub async fn export_cross_signing_keys(&self) -> Option<CrossSigningKeyExport> {
let master_key = self.store.export_secret(&SecretName::CrossSigningMasterKey).await;
let master_key = self.store().export_secret(&SecretName::CrossSigningMasterKey).await;
let self_signing_key =
self.store.export_secret(&SecretName::CrossSigningSelfSigningKey).await;
self.store().export_secret(&SecretName::CrossSigningSelfSigningKey).await;
let user_signing_key =
self.store.export_secret(&SecretName::CrossSigningUserSigningKey).await;
self.store().export_secret(&SecretName::CrossSigningUserSigningKey).await;
if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() {
None
@@ -1658,14 +1682,14 @@ impl OlmMachine {
&self,
export: CrossSigningKeyExport,
) -> Result<CrossSigningStatus, SecretImportError> {
self.store.import_cross_signing_keys(export).await
self.store().import_cross_signing_keys(export).await
}
async fn sign_with_master_key(
&self,
message: &str,
) -> Result<(OwnedDeviceKeyId, Ed25519Signature), SignatureError> {
let identity = &*self.user_identity.lock().await;
let identity = &*self.inner.user_identity.lock().await;
let key_id = identity.master_key_id().await.ok_or(SignatureError::MissingSigningKey)?;
let signature = identity.sign(message).await?;
@@ -1681,8 +1705,8 @@ impl OlmMachine {
pub async fn sign(&self, message: &str) -> Signatures {
let mut signatures = Signatures::new();
let key_id = self.account.signing_key_id();
let signature = self.account.sign(message).await;
let key_id = self.inner.account.signing_key_id();
let signature = self.inner.account.sign(message).await;
signatures.add_signature(self.user_id().to_owned(), key_id, signature);
match self.sign_with_master_key(message).await {
@@ -1703,7 +1727,7 @@ impl OlmMachine {
/// the server.
#[cfg(feature = "backups_v1")]
pub fn backup_machine(&self) -> &BackupMachine {
&self.backup_machine
&self.inner.backup_machine
}
}
@@ -1833,7 +1857,7 @@ pub(crate) mod tests {
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(user_id(), bob_device_id()).await;
machine.account.inner.update_uploaded_key_count(0);
machine.account().update_uploaded_key_count(0);
let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
let response = keys_upload_response();
machine.receive_keys_upload_response(&response).await.unwrap();
@@ -1860,8 +1884,8 @@ pub(crate) mod tests {
let alice_device = ReadOnlyDevice::from_machine(&alice).await;
let bob_device = ReadOnlyDevice::from_machine(&bob).await;
alice.store.save_devices(&[bob_device]).await.unwrap();
bob.store.save_devices(&[alice_device]).await.unwrap();
alice.store().save_devices(&[bob_device]).await.unwrap();
bob.store().save_devices(&[alice_device]).await.unwrap();
(alice, bob, otk)
}
@@ -1890,19 +1914,19 @@ pub(crate) mod tests {
let (alice, bob) = get_machine_pair_with_session().await;
let bob_device =
alice.get_device(&bob.user_id, &bob.device_id, None).await.unwrap().unwrap();
alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
let (session, content) = bob_device
.encrypt("m.dummy", serde_json::to_value(ToDeviceDummyEventContent::new()).unwrap())
.await
.unwrap();
alice.store.save_sessions(&[session]).await.unwrap();
alice.store().save_sessions(&[session]).await.unwrap();
let event =
ToDeviceEvent::new(alice.user_id().to_owned(), content.deserialize_as().unwrap());
let decrypted = bob.decrypt_to_device_event(&event).await.unwrap();
bob.store.save_sessions(&[decrypted.session.session()]).await.unwrap();
bob.store().save_sessions(&[decrypted.session.session()]).await.unwrap();
(alice, bob)
}
@@ -1925,28 +1949,28 @@ pub(crate) mod tests {
async fn generate_one_time_keys() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
assert!(machine.account.generate_one_time_keys().await.is_some());
assert!(machine.account().generate_one_time_keys().await.is_some());
let mut response = keys_upload_response();
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.account.generate_one_time_keys().await.is_some());
assert!(machine.account().generate_one_time_keys().await.is_some());
response.one_time_key_counts.insert(DeviceKeyAlgorithm::SignedCurve25519, uint!(50));
machine.receive_keys_upload_response(&response).await.unwrap();
assert!(machine.account.generate_one_time_keys().await.is_none());
assert!(machine.account().generate_one_time_keys().await.is_none());
}
#[async_test]
async fn test_device_key_signing() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let device_keys = machine.account.device_keys().await;
let identity_keys = machine.account.identity_keys();
let device_keys = machine.account().device_keys().await;
let identity_keys = machine.account().identity_keys();
let ed25519_key = identity_keys.ed25519;
let ret = ed25519_key.verify_json(
&machine.user_id,
machine.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
&device_keys,
);
@@ -1959,11 +1983,12 @@ pub(crate) mod tests {
let room_id = room_id!("!test:example.org");
machine.create_outbound_group_session_with_defaults(room_id).await.unwrap();
assert!(machine.group_session_manager.get_outbound_group_session(room_id).is_some());
assert!(machine.inner.group_session_manager.get_outbound_group_session(room_id).is_some());
machine.invalidate_group_session(room_id).await.unwrap();
assert!(machine
.inner
.group_session_manager
.get_outbound_group_session(room_id)
.unwrap()
@@ -1974,12 +1999,12 @@ pub(crate) mod tests {
async fn test_invalid_signature() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let device_keys = machine.account.device_keys().await;
let device_keys = machine.account().device_keys().await;
let key = Ed25519PublicKey::from_slice(&[0u8; 32]).unwrap();
let ret = key.verify_json(
&machine.user_id,
machine.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
&device_keys,
);
@@ -1989,10 +2014,10 @@ pub(crate) mod tests {
#[async_test]
async fn one_time_key_signing() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
machine.account.inner.update_uploaded_key_count(49);
machine.account().update_uploaded_key_count(49);
let mut one_time_keys = machine.account.signed_one_time_keys().await;
let ed25519_key = machine.account.identity_keys().ed25519;
let mut one_time_keys = machine.account().signed_one_time_keys().await;
let ed25519_key = machine.account().identity_keys().ed25519;
let one_time_key: SignedKey = one_time_keys
.values_mut()
@@ -2003,7 +2028,7 @@ pub(crate) mod tests {
ed25519_key
.verify_json(
&machine.user_id,
machine.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
&one_time_key,
)
@@ -2013,9 +2038,9 @@ pub(crate) mod tests {
#[async_test]
async fn test_keys_for_upload() {
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
machine.account.inner.update_uploaded_key_count(0);
machine.account().update_uploaded_key_count(0);
let ed25519_key = machine.account.identity_keys().ed25519;
let ed25519_key = machine.account().identity_keys().ed25519;
let mut request =
machine.keys_for_upload().await.expect("Can't prepare initial key upload");
@@ -2029,7 +2054,7 @@ pub(crate) mod tests {
.unwrap();
let ret = ed25519_key.verify_json(
&machine.user_id,
machine.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
&one_time_key,
);
@@ -2038,7 +2063,7 @@ pub(crate) mod tests {
let device_keys: DeviceKeys = request.device_keys.unwrap().deserialize_as().unwrap();
let ret = ed25519_key.verify_json(
&machine.user_id,
machine.user_id(),
&DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, machine.device_id()),
&device_keys,
);
@@ -2063,13 +2088,13 @@ pub(crate) mod tests {
let alice_id = user_id!("@alice:example.org");
let alice_device_id: &DeviceId = device_id!("JLAFKJWSCS");
let alice_devices = machine.store.get_user_devices(alice_id).await.unwrap();
let alice_devices = machine.store().get_user_devices(alice_id).await.unwrap();
assert!(alice_devices.devices().peekable().peek().is_none());
let req_id = TransactionId::new();
machine.receive_keys_query_response(&req_id, &response).await.unwrap();
let device = machine.store.get_device(alice_id, alice_device_id).await.unwrap().unwrap();
let device = machine.store().get_device(alice_id, alice_device_id).await.unwrap().unwrap();
assert_eq!(device.user_id(), alice_id);
assert_eq!(device.device_id(), alice_device_id);
}
@@ -2119,8 +2144,8 @@ pub(crate) mod tests {
.await;
let session = alice_machine
.store
.get_sessions(&bob_machine.account.identity_keys().curve25519.to_base64())
.store()
.get_sessions(&bob_machine.account().identity_keys().curve25519.to_base64())
.await
.unwrap()
.unwrap();
@@ -2168,7 +2193,7 @@ pub(crate) mod tests {
// with the same timestamps, so we manually masage them to be 10s apart.
let session_id = {
let sessions = alice_machine
.store
.store()
.get_sessions(&bob_machine.identity_keys().curve25519.to_base64())
.await
.unwrap()
@@ -2207,7 +2232,7 @@ pub(crate) mod tests {
let (alice, bob) = get_machine_pair_with_session().await;
let bob_device =
alice.get_device(&bob.user_id, &bob.device_id, None).await.unwrap().unwrap();
alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
let event = ToDeviceEvent::new(
alice.user_id().to_owned(),
@@ -2254,7 +2279,7 @@ pub(crate) mod tests {
let event = json_convert(&event).unwrap();
let alice_session =
alice.group_session_manager.get_outbound_group_session(room_id).unwrap();
alice.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
let decrypted = bob
.receive_sync_changes(vec![event], &Default::default(), &Default::default(), None)
@@ -2271,7 +2296,7 @@ pub(crate) mod tests {
}
let session =
bob.store.get_inbound_group_session(room_id, alice_session.session_id()).await;
bob.store().get_inbound_group_session(room_id, alice_session.session_id()).await;
assert!(session.unwrap().is_some());
}
@@ -2295,7 +2320,7 @@ pub(crate) mod tests {
let group_session =
bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session.unwrap();
bob.store.save_inbound_group_sessions(&[group_session.clone()]).await.unwrap();
bob.store().save_inbound_group_sessions(&[group_session.clone()]).await.unwrap();
// when we decrypt the room key, the
// inbound_group_session_streamroom_keys_received_stream should tell us
@@ -2441,7 +2466,7 @@ pub(crate) mod tests {
let export = group_session.as_ref().unwrap().clone().export().await;
bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
bob.store().save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
let plaintext = "It is a secret to everybody";
@@ -2525,7 +2550,7 @@ pub(crate) mod tests {
// Simulate an imported session, to change verification state
let imported = InboundGroupSession::from_export(&export).unwrap();
bob.store.save_inbound_group_sessions(&[imported]).await.unwrap();
bob.store().save_inbound_group_sessions(&[imported]).await.unwrap();
let encryption_info =
bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap();
@@ -2751,7 +2776,7 @@ pub(crate) mod tests {
assert_eq!(device_change.new.len(), 2);
assert_eq!(identity_change.new.len(), 1);
//
let devices = bob.store.get_user_devices(other_user_id).await.unwrap();
let devices = bob.store().get_user_devices(other_user_id).await.unwrap();
assert_eq!(devices.devices().count(), 2);
let fake_room_id = room_id!("!roomid:example.com");
@@ -2760,7 +2785,7 @@ pub(crate) mod tests {
// We will use the export to create various inbounds with other claimed
// ownership
let id_keys = bob.identity_keys();
let fake_device_id = bob.device_id.clone();
let fake_device_id = bob.device_id().into();
let olm = OutboundGroupSession::new(
fake_device_id,
Arc::new(id_keys),
@@ -2816,7 +2841,7 @@ pub(crate) mod tests {
let bob_other_device = device_id!("OTHERBOB");
let bob_other_machine = OlmMachine::new(bob_id, bob_other_device).await;
let bob_other_device = ReadOnlyDevice::from_machine(&bob_other_machine).await;
bob.store.save_devices(&[bob_other_device]).await.unwrap();
bob.store().save_devices(&[bob_other_device]).await.unwrap();
bob.get_device(bob_id, device_id!("OTHERBOB"), None)
.await
.unwrap()
@@ -2855,7 +2880,7 @@ pub(crate) mod tests {
let group_session =
bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session;
bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
bob.store().save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap();
let room_event = json_convert(&room_event).unwrap();
@@ -2865,7 +2890,7 @@ pub(crate) mod tests {
if let vodozemac::megolm::DecryptionError::UnknownMessageIndex(_, _) = vodo_error {
// check that key has been requested
let outgoing_to_devices =
bob.key_request_machine.outgoing_to_device_requests().await.unwrap();
bob.inner.key_request_machine.outgoing_to_device_requests().await.unwrap();
assert_eq!(1, outgoing_to_devices.len());
} else {
panic!("Should be UnknownMessageIndex error ")
@@ -2903,6 +2928,7 @@ pub(crate) mod tests {
alice.handle_verification_event(&event).await;
let (event, request_id) = alice
.inner
.verification_machine
.outgoing_messages()
.first()
@@ -2912,6 +2938,7 @@ pub(crate) mod tests {
bob.handle_verification_event(&event).await;
let (event, request_id) = bob
.inner
.verification_machine
.outgoing_messages()
.first()
@@ -3026,11 +3053,11 @@ pub(crate) mod tests {
alice.handle_verification_event(&event).await;
// Alice sends a key
let msgs = alice.verification_machine.outgoing_messages();
let msgs = alice.inner.verification_machine.outgoing_messages();
assert!(msgs.len() == 1);
let msg = msgs.first().unwrap();
let event = outgoing_request_to_event(alice.user_id(), msg);
alice.verification_machine.mark_request_as_sent(&msg.request_id);
alice.inner.verification_machine.mark_request_as_sent(&msg.request_id);
// ----------------------------------------------------------------------------
// On Bob's device:
@@ -3039,11 +3066,11 @@ pub(crate) mod tests {
bob.handle_verification_event(&event).await;
// Now bob sends a key
let msgs = bob.verification_machine.outgoing_messages();
let msgs = bob.inner.verification_machine.outgoing_messages();
assert!(msgs.len() == 1);
let msg = msgs.first().unwrap();
let event = outgoing_request_to_event(bob.user_id(), msg);
bob.verification_machine.mark_request_as_sent(&msg.request_id);
bob.inner.verification_machine.mark_request_as_sent(&msg.request_id);
// ----------------------------------------------------------------------------
// On Alice's device:
@@ -3090,7 +3117,7 @@ pub(crate) mod tests {
bob.handle_verification_event(&event_mac).await;
// Bob verifies that the MAC is valid and also sends a "done" message.
let msgs = bob.verification_machine.outgoing_messages();
let msgs = bob.inner.verification_machine.outgoing_messages();
eprintln!("{msgs:?}");
assert!(msgs.len() == 1);
let event = msgs.first().map(|r| outgoing_request_to_event(bob.user_id(), r)).unwrap();
@@ -3174,7 +3201,7 @@ pub(crate) mod tests {
bob.receive_sync_changes(vec![event], &changed_devices, &key_counts, None).await.unwrap();
let session = bob.store.get_inbound_group_session(room_id, &session_id).await;
let session = bob.store().get_inbound_group_session(room_id, &session_id).await;
assert!(session.unwrap().is_none());
}
@@ -3184,10 +3211,10 @@ pub(crate) mod tests {
let room_id = room_id!("!test:localhost");
let (alice, _) = get_machine_pair_with_setup_sessions().await;
let device = ReadOnlyDevice::from_machine(&alice).await;
alice.store.save_devices(&[device]).await.unwrap();
alice.store().save_devices(&[device]).await.unwrap();
let (outbound, mut inbound) =
alice.account.create_group_session_pair(room_id, Default::default()).await.unwrap();
alice.account().create_group_session_pair(room_id, Default::default()).await.unwrap();
let fake_key = Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE")
.unwrap()
@@ -3197,7 +3224,7 @@ pub(crate) mod tests {
let content = json!({});
let content = outbound.encrypt(content, "m.dummy").await;
alice.store.save_inbound_group_sessions(&[inbound]).await.unwrap();
alice.store().save_inbound_group_sessions(&[inbound]).await.unwrap();
let event = json!({
"sender": alice.user_id(),

View File

@@ -926,7 +926,8 @@ mod tests {
let requests =
machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap();
let outbound =
machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
assert!(!outbound.pending_requests().is_empty());
assert!(!outbound.shared());
@@ -1066,7 +1067,8 @@ mod tests {
.filter(|r| r.event_type == "m.room.encrypted".into())
.map(|r| r.message_count())
.sum();
let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap();
let outbound =
machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
assert_eq!(event_count, 1);
assert!(!outbound.pending_requests().is_empty());
@@ -1079,9 +1081,11 @@ mod tests {
let keys_claim = keys_claim_response();
let users = keys_claim.one_time_keys.keys().map(Deref::deref);
let outbound = machine.group_session_manager.get_outbound_group_session(room_id).unwrap();
let outbound =
machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
let CollectRecipientsResult { should_rotate, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
.await
@@ -1095,6 +1099,7 @@ mod tests {
};
let CollectRecipientsResult { should_rotate, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users.clone(), &settings, &outbound)
.await
@@ -1108,6 +1113,7 @@ mod tests {
};
let CollectRecipientsResult { should_rotate, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
@@ -1127,6 +1133,7 @@ mod tests {
let machine = machine_with_user(user_id, device_id).await;
let (outbound, _) = machine
.inner
.group_session_manager
.get_or_create_outbound_session(room_id, EncryptionSettings::default())
.await
@@ -1137,6 +1144,7 @@ mod tests {
let users = [user_id].into_iter();
let CollectRecipientsResult { devices: recipients, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
@@ -1154,6 +1162,7 @@ mod tests {
let users = [user_id].into_iter();
let CollectRecipientsResult { devices: recipients, .. } = machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await
@@ -1168,6 +1177,7 @@ mod tests {
let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
machine
.inner
.group_session_manager
.collect_session_recipients(users, &settings, &outbound)
.await