From a311dcbd3e706e9686527b38b1c05db1646f12f7 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 17 Jan 2025 10:05:51 +0100 Subject: [PATCH 1/6] refact(mem_store): Replace unwrap with expect --- crates/matrix-sdk-crypto/src/store/memorystore.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 1e6b4ae7f..2a9ba3bcb 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -235,7 +235,8 @@ impl CryptoStore for MemoryStore { for identity in changes.identities.new.into_iter().chain(changes.identities.changed) { identities.insert( identity.user_id().to_owned(), - serde_json::to_string(&identity).unwrap(), + serde_json::to_string(&identity) + .expect("UserIdentityData should always serialize to json"), ); } } @@ -535,9 +536,13 @@ impl CryptoStore for MemoryStore { } async fn get_own_device(&self) -> Result { - let account = self.load_account().await?.unwrap(); + let account = self.load_account().await? + .expect("Failed to load account"); - Ok(self.devices.get(&account.user_id, &account.device_id).unwrap()) + Ok( + self.devices.get(&account.user_id, &account.device_id) + .expect("Invalid state: Should always have a own device") + ) } async fn get_user_identity(&self, user_id: &UserId) -> Result> { @@ -545,7 +550,8 @@ impl CryptoStore for MemoryStore { match serialized { None => Ok(None), Some(serialized) => { - let id: UserIdentityData = serde_json::from_str(serialized.as_str()).unwrap(); + let id: UserIdentityData = serde_json::from_str(serialized.as_str()) + .expect("Only valid serialized identity are saved"); Ok(Some(id)) } } From 9ff3761cac3c4410e05db9afafb28c16a5bf0577 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 17 Jan 2025 10:28:19 +0100 Subject: [PATCH 2/6] refact(mem_store): Serialize account and cache static account data --- .../src/store/memorystore.rs | 63 +++++++++++++------ 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 2a9ba3bcb..e58259d78 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -15,6 +15,7 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::Infallible, + sync::Arc, }; use async_trait::async_trait; @@ -37,7 +38,10 @@ use super::{ use crate::{ gossiping::{GossipRequest, GossippedSecret, SecretInfo}, identities::{DeviceData, UserIdentityData}, - olm::{OutboundGroupSession, PrivateCrossSigningIdentity, SenderDataType}, + olm::{ + OutboundGroupSession, PickledAccount, PrivateCrossSigningIdentity, SenderDataType, + StaticAccountData, + }, types::events::room_key_withheld::RoomKeyWithheldEvent, TrackedUser, }; @@ -70,7 +74,9 @@ impl BackupVersion { /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug)] pub struct MemoryStore { - account: StdRwLock>, + static_account: Arc>>, + + account: StdRwLock>, sessions: StdRwLock>>, inbound_group_sessions: GroupSessionStore, @@ -101,6 +107,7 @@ pub struct MemoryStore { impl Default for MemoryStore { fn default() -> Self { MemoryStore { + static_account: Default::default(), account: Default::default(), sessions: Default::default(), inbound_group_sessions: GroupSessionStore::new(), @@ -131,6 +138,10 @@ impl MemoryStore { Self::default() } + fn get_static_account(&self) -> Option { + self.static_account.read().clone() + } + pub(crate) fn save_devices(&self, devices: Vec) { for device in devices { let _ = self.devices.add(device); @@ -201,7 +212,21 @@ impl CryptoStore for MemoryStore { type Error = Infallible; async fn load_account(&self) -> Result> { - Ok(self.account.read().as_ref().map(|acc| acc.deep_clone())) + let pickled_account: Option = self.account.read().as_ref().map(|acc| { + serde_json::from_str(acc) + .expect("Deserialization failed: invalid pickled account JSON format") + }); + + if let Some(pickle) = pickled_account { + let account = + Account::from_pickle(pickle).expect("From pickle failed: invalid pickle format"); + + *self.static_account.write() = Some(account.static_data().clone()); + + Ok(Some(account)) + } else { + Ok(None) + } } async fn load_identity(&self) -> Result> { @@ -213,9 +238,17 @@ impl CryptoStore for MemoryStore { } async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> { - if let Some(account) = changes.account { - *self.account.write() = Some(account); - } + let pickled_account = if let Some(account) = changes.account { + *self.static_account.write() = Some(account.static_data().clone()); + Some(account.pickle()) + } else { + None + }; + + *self.account.write() = pickled_account.map(|pickle| { + serde_json::to_string(&pickle) + .expect("Serialization failed: invalid pickled account JSON format") + }); Ok(()) } @@ -536,13 +569,13 @@ impl CryptoStore for MemoryStore { } async fn get_own_device(&self) -> Result { - let account = self.load_account().await? - .expect("Failed to load account"); + let account = + self.get_static_account().expect("Expect account to exist when getting own device"); - Ok( - self.devices.get(&account.user_id, &account.device_id) - .expect("Invalid state: Should always have a own device") - ) + Ok(self + .devices + .get(&account.user_id, &account.device_id) + .expect("Invalid state: Should always have a own device")) } async fn get_user_identity(&self, user_id: &UserId) -> Result> { @@ -1160,12 +1193,6 @@ mod integration_tests { } } - impl MemoryStore { - fn get_static_account(&self) -> Option { - self.account.read().as_ref().map(|acc| acc.static_data().clone()) - } - } - /// Return a clone of the store for the test with the supplied name. Note: /// dropping this store won't destroy its data, since /// [PersistentMemoryStore] is a reference-counted smart pointer From fe85cddf880ffdb12ab10cca0f432a237369a048 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 17 Jan 2025 13:22:13 +0100 Subject: [PATCH 3/6] refact(mem_store): Serialize/Deserialise olm sessions --- .../src/store/memorystore.rs | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index e58259d78..2e73f797e 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -39,8 +39,8 @@ use crate::{ gossiping::{GossipRequest, GossippedSecret, SecretInfo}, identities::{DeviceData, UserIdentityData}, olm::{ - OutboundGroupSession, PickledAccount, PrivateCrossSigningIdentity, SenderDataType, - StaticAccountData, + OutboundGroupSession, PickledAccount, PickledSession, PrivateCrossSigningIdentity, + SenderDataType, StaticAccountData, }, types::events::room_key_withheld::RoomKeyWithheldEvent, TrackedUser, @@ -77,7 +77,8 @@ pub struct MemoryStore { static_account: Arc>>, account: StdRwLock>, - sessions: StdRwLock>>, + // Map of sender_key to map of session_id to serialized pickle + sessions: StdRwLock>>, inbound_group_sessions: GroupSessionStore, /// Map room id -> session id -> backup order number @@ -154,17 +155,17 @@ impl MemoryStore { } } - fn save_sessions(&self, sessions: Vec) { + fn save_sessions(&self, sessions: Vec<(String, PickledSession)>) { let mut session_store = self.sessions.write(); - for session in sessions { - let entry = session_store.entry(session.sender_key().to_base64()).or_default(); + for (session_id, pickle) in sessions { + let entry = session_store.entry(pickle.sender_key.to_base64()).or_default(); - if let Some(old_entry) = entry.iter_mut().find(|entry| &session == *entry) { - *old_entry = session; - } else { - entry.push(session); - } + // insert or replace if exists + entry.insert( + session_id, + serde_json::to_string(&pickle).expect("Failed to serialize olm session"), + ); } } @@ -254,7 +255,14 @@ impl CryptoStore for MemoryStore { } async fn save_changes(&self, changes: Changes) -> Result<()> { - self.save_sessions(changes.sessions); + let mut pickled_session: Vec<(String, PickledSession)> = Vec::new(); + for session in changes.sessions { + let session_id = session.session_id().to_owned(); + let pickle = session.pickle().await; + pickled_session.push((session_id.clone(), pickle)); + } + self.save_sessions(pickled_session); + self.save_inbound_group_sessions(changes.inbound_group_sessions, None).await?; self.save_outbound_group_sessions(changes.outbound_group_sessions); self.save_private_identity(changes.private_identity); @@ -372,7 +380,21 @@ impl CryptoStore for MemoryStore { } async fn get_sessions(&self, sender_key: &str) -> Result>> { - Ok(self.sessions.read().get(sender_key).cloned()) + let device_keys = self.get_own_device().await?.as_device_keys().clone(); + + if let Some(pickles) = self.sessions.read().get(sender_key) { + let mut sessions: Vec = Vec::new(); + for serialized_pickle in pickles.values() { + let pickle: PickledSession = serde_json::from_str(serialized_pickle.as_str()) + .expect("Pickle pickle deserialization should work"); + let session = Session::from_pickle(device_keys.clone(), pickle) + .expect("Expect from pickle to always work"); + sessions.push(session); + } + Ok(Some(sessions)) + } else { + Ok(None) + } } async fn get_inbound_group_session( @@ -695,18 +717,31 @@ mod tests { tests::get_account_and_session_test_helper, Account, InboundGroupSession, OlmMessageHash, PrivateCrossSigningIdentity, SenderData, }, - store::{memorystore::MemoryStore, Changes, CryptoStore, PendingChanges}, + store::{memorystore::MemoryStore, Changes, CryptoStore, DeviceChanges, PendingChanges}, + DeviceData, }; #[async_test] async fn test_session_store() { let (account, session) = get_account_and_session_test_helper(); + let own_device = DeviceData::from_account(&account); let store = MemoryStore::new(); assert!(store.load_account().await.unwrap().is_none()); + + store + .save_changes(Changes { + devices: DeviceChanges { new: vec![own_device], ..Default::default() }, + ..Default::default() + }) + .await + .unwrap(); store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); - store.save_sessions(vec![session.clone()]); + store + .save_changes(Changes { sessions: (vec![session.clone()]), ..Default::default() }) + .await + .unwrap(); let sessions = store.get_sessions(&session.sender_key.to_base64()).await.unwrap().unwrap(); From 182fc6fd8f784d524628d853cfdc8b61f2304625 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 17 Jan 2025 15:39:02 +0100 Subject: [PATCH 4/6] refact(mem_store): Ser/Deser inbound group sessions --- crates/matrix-sdk-crypto/src/store/caches.rs | 91 +------------------ .../src/store/memorystore.rs | 77 +++++++++++----- 2 files changed, 60 insertions(+), 108 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/store/caches.rs b/crates/matrix-sdk-crypto/src/store/caches.rs index ca3954edf..72cd40b09 100644 --- a/crates/matrix-sdk-crypto/src/store/caches.rs +++ b/crates/matrix-sdk-crypto/src/store/caches.rs @@ -27,15 +27,12 @@ use std::{ }; use matrix_sdk_common::locks::RwLock as StdRwLock; -use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId}; use serde::{Deserialize, Serialize}; use tokio::sync::{Mutex, RwLock}; use tracing::{field::display, instrument, trace, Span}; -use crate::{ - identities::DeviceData, - olm::{InboundGroupSession, Session}, -}; +use crate::{identities::DeviceData, olm::Session}; /// In-memory store for Olm Sessions. #[derive(Debug, Default, Clone)] @@ -86,52 +83,6 @@ impl SessionStore { } } -#[derive(Debug, Default)] -/// In-memory store that holds inbound group sessions. -pub struct GroupSessionStore { - entries: StdRwLock>>, -} - -impl GroupSessionStore { - /// Create a new empty store. - pub fn new() -> Self { - Self::default() - } - - /// Add an inbound group session to the store. - /// - /// Returns true if the session was added, false if the session was - /// already in the store. - pub fn add(&self, session: InboundGroupSession) -> bool { - self.entries - .write() - .entry(session.room_id().to_owned()) - .or_default() - .insert(session.session_id().to_owned(), session) - .is_none() - } - - /// Get all the group sessions the store knows about. - pub fn get_all(&self) -> Vec { - self.entries.read().values().flat_map(HashMap::values).cloned().collect() - } - - /// Get the number of `InboundGroupSession`s we have. - pub fn count(&self) -> usize { - self.entries.read().values().map(HashMap::len).sum() - } - - /// Get a inbound group session from our store. - /// - /// # Arguments - /// * `room_id` - The room id of the room that the session belongs to. - /// - /// * `session_id` - The unique id of the session. - pub fn get(&self, room_id: &RoomId, session_id: &str) -> Option { - self.entries.read().get(room_id)?.get(session_id).cloned() - } -} - /// In-memory store holding the devices of users. #[derive(Debug, Default)] pub struct DeviceStore { @@ -381,13 +332,10 @@ impl UsersForKeyQuery { mod tests { use matrix_sdk_test::async_test; use proptest::prelude::*; - use ruma::room_id; - use vodozemac::{Curve25519PublicKey, Ed25519PublicKey}; - use super::{DeviceStore, GroupSessionStore, SequenceNumber, SessionStore}; + use super::{DeviceStore, SequenceNumber, SessionStore}; use crate::{ - identities::device::testing::get_device, - olm::{tests::get_account_and_session_test_helper, InboundGroupSession, SenderData}, + identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper, }; #[async_test] @@ -422,37 +370,6 @@ mod tests { assert_eq!(&session, loaded_session); } - #[async_test] - async fn test_group_session_store() { - let (account, _) = get_account_and_session_test_helper(); - let room_id = room_id!("!test:localhost"); - let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw"; - - let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await; - - assert_eq!(0, outbound.message_index().await); - assert!(!outbound.shared()); - outbound.mark_as_shared(); - assert!(outbound.shared()); - - let inbound = InboundGroupSession::new( - Curve25519PublicKey::from_base64(curve_key).unwrap(), - Ed25519PublicKey::from_base64("ee3Ek+J2LkkPmjGPGLhMxiKnhiX//xcqaVL4RP6EypE").unwrap(), - room_id, - &outbound.session_key().await, - SenderData::unknown(), - outbound.settings().algorithm.to_owned(), - None, - ) - .unwrap(); - - let store = GroupSessionStore::new(); - store.add(inbound.clone()); - - let loaded_session = store.get(room_id, outbound.session_id()).unwrap(); - assert_eq!(inbound, loaded_session); - } - #[async_test] async fn test_device_store() { let device = get_device(); diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 2e73f797e..1cad81122 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -31,16 +31,15 @@ use tracing::warn; use vodozemac::Curve25519PublicKey; use super::{ - caches::{DeviceStore, GroupSessionStore}, - Account, BackupKeys, Changes, CryptoStore, DehydratedDeviceKey, InboundGroupSession, - PendingChanges, RoomKeyCounts, RoomSettings, Session, + caches::DeviceStore, Account, BackupKeys, Changes, CryptoStore, DehydratedDeviceKey, + InboundGroupSession, PendingChanges, RoomKeyCounts, RoomSettings, Session, }; use crate::{ gossiping::{GossipRequest, GossippedSecret, SecretInfo}, identities::{DeviceData, UserIdentityData}, olm::{ - OutboundGroupSession, PickledAccount, PickledSession, PrivateCrossSigningIdentity, - SenderDataType, StaticAccountData, + OutboundGroupSession, PickledAccount, PickledInboundGroupSession, PickledSession, + PrivateCrossSigningIdentity, SenderDataType, StaticAccountData, }, types::events::room_key_withheld::RoomKeyWithheldEvent, TrackedUser, @@ -79,7 +78,7 @@ pub struct MemoryStore { account: StdRwLock>, // Map of sender_key to map of session_id to serialized pickle sessions: StdRwLock>>, - inbound_group_sessions: GroupSessionStore, + inbound_group_sessions: StdRwLock>>, /// Map room id -> session id -> backup order number /// The latest backup in which this session is stored. Equivalent to @@ -111,7 +110,7 @@ impl Default for MemoryStore { static_account: Default::default(), account: Default::default(), sessions: Default::default(), - inbound_group_sessions: GroupSessionStore::new(), + inbound_group_sessions: Default::default(), inbound_group_sessions_backed_up_to: Default::default(), outbound_group_sessions: Default::default(), private_identity: Default::default(), @@ -351,9 +350,6 @@ impl CryptoStore for MemoryStore { sessions: Vec, backed_up_to_version: Option<&str>, ) -> Result<()> { - let mut inbound_group_sessions_backed_up_to = - self.inbound_group_sessions_backed_up_to.write(); - for session in sessions { let room_id = session.room_id(); let session_id = session.session_id(); @@ -369,12 +365,23 @@ impl CryptoStore for MemoryStore { } if let Some(backup_version) = backed_up_to_version { - inbound_group_sessions_backed_up_to + self.inbound_group_sessions_backed_up_to + .write() .entry(room_id.to_owned()) .or_default() .insert(session_id.to_owned(), BackupVersion::from(backup_version)); } - self.inbound_group_sessions.add(session); + + let pickle = session.pickle().await; + self.inbound_group_sessions + .write() + .entry(session.room_id().to_owned()) + .or_default() + .insert( + session.session_id().to_owned(), + serde_json::to_string(&pickle) + .expect("Pickle pickle data should serialize to json"), + ); } Ok(()) } @@ -402,7 +409,18 @@ impl CryptoStore for MemoryStore { room_id: &RoomId, session_id: &str, ) -> Result> { - Ok(self.inbound_group_sessions.get(room_id, session_id)) + let pickle: Option = self + .inbound_group_sessions + .read() + .get(room_id) + .and_then(|m| m.get(session_id)) + .and_then(|ser| { + serde_json::from_str(ser).expect("Pickle pickle deserialization should work") + }); + + Ok(pickle.map(|p| { + InboundGroupSession::from_pickle(p).expect("Expect from pickle to always work") + })) } async fn get_withheld_info( @@ -418,7 +436,18 @@ impl CryptoStore for MemoryStore { } async fn get_inbound_group_sessions(&self) -> Result> { - Ok(self.inbound_group_sessions.get_all()) + let inbounds = self + .inbound_group_sessions + .read() + .values() + .flat_map(HashMap::values) + .map(|ser| { + let pickle: PickledInboundGroupSession = + serde_json::from_str(ser).expect("Pickle deserialization should work"); + InboundGroupSession::from_pickle(pickle).expect("Expect from pickle to always work") + }) + .collect(); + Ok(inbounds) } async fn inbound_group_session_counts( @@ -439,7 +468,8 @@ impl CryptoStore for MemoryStore { 0 }; - Ok(RoomKeyCounts { total: self.inbound_group_sessions.count(), backed_up }) + let total = self.inbound_group_sessions.read().values().map(HashMap::len).sum(); + Ok(RoomKeyCounts { total, backed_up }) } async fn get_inbound_group_sessions_for_device_batch( @@ -511,21 +541,26 @@ impl CryptoStore for MemoryStore { backup_version: &str, room_and_session_ids: &[(&RoomId, &str)], ) -> Result<()> { - let mut inbound_group_sessions_backed_up_to = - self.inbound_group_sessions_backed_up_to.write(); - for &(room_id, session_id) in room_and_session_ids { - let session = self.inbound_group_sessions.get(room_id, session_id); + let session = self.get_inbound_group_session(room_id, session_id).await?; if let Some(session) = session { session.mark_as_backed_up(); - inbound_group_sessions_backed_up_to + self.inbound_group_sessions_backed_up_to + .write() .entry(room_id.to_owned()) .or_default() .insert(session_id.to_owned(), BackupVersion::from(backup_version)); - self.inbound_group_sessions.add(session); + // Save it back + let updated_pickle = session.pickle().await; + + self.inbound_group_sessions.write().entry(room_id.to_owned()).or_default().insert( + session_id.to_owned(), + serde_json::to_string(&updated_pickle) + .expect("Pickle serialization should work"), + ); } } From 50696a0d747653e17b01a8225ba629c8b49ff6ff Mon Sep 17 00:00:00 2001 From: Valere Date: Mon, 20 Jan 2025 16:45:37 +0100 Subject: [PATCH 5/6] refact(mem_store) Add a global save change lock similar to other stores --- crates/matrix-sdk-crypto/src/store/memorystore.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 1cad81122..6b9e81e5a 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -26,7 +26,7 @@ use ruma::{ events::secret::request::SecretName, time::Instant, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId, }; -use tokio::sync::RwLock; +use tokio::sync::{Mutex, RwLock}; use tracing::warn; use vodozemac::Curve25519PublicKey; @@ -102,6 +102,7 @@ pub struct MemoryStore { dehydrated_device_pickle_key: RwLock>, next_batch_token: RwLock>, room_settings: StdRwLock>, + save_changes_lock: Arc>, } impl Default for MemoryStore { @@ -128,6 +129,7 @@ impl Default for MemoryStore { secret_inbox: Default::default(), next_batch_token: Default::default(), room_settings: Default::default(), + save_changes_lock: Default::default(), } } } @@ -238,6 +240,8 @@ impl CryptoStore for MemoryStore { } async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> { + let _guard = self.save_changes_lock.lock().await; + let pickled_account = if let Some(account) = changes.account { *self.static_account.write() = Some(account.static_data().clone()); Some(account.pickle()) @@ -254,6 +258,8 @@ impl CryptoStore for MemoryStore { } async fn save_changes(&self, changes: Changes) -> Result<()> { + let _guard = self.save_changes_lock.lock().await; + let mut pickled_session: Vec<(String, PickledSession)> = Vec::new(); for session in changes.sessions { let session_id = session.session_id().to_owned(); From 542d68dcda75c41c5c330d3fe272f45b37340b90 Mon Sep 17 00:00:00 2001 From: Valere Date: Tue, 21 Jan 2025 10:34:58 +0100 Subject: [PATCH 6/6] fix(test): Properly set up test crypto memory store by saving own device --- .../src/gossiping/machine.rs | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index 31eb23f23..0bb1204e8 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -1115,6 +1115,7 @@ mod tests { use crate::{ gossiping::KeyForwardDecision, olm::OutboundGroupSession, + store::{CryptoStore, DeviceChanges}, types::requests::AnyOutgoingRequest, types::{ events::{ @@ -1177,13 +1178,12 @@ mod tests { let user_id = user_id.to_owned(); let device_id = DeviceId::new(); - let account = Account::with_device_id(&user_id, &device_id); - let store = Arc::new(CryptoStoreWrapper::new(&user_id, &device_id, MemoryStore::new())); + let store = Arc::new(store_with_account_helper(&user_id, &device_id).await); + let static_data = store.load_account().await.unwrap().unwrap().static_data; let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let verification = - VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account.static_data().clone(), identity, store, verification); - store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); + VerificationMachine::new(static_data.clone(), identity.clone(), store.clone()); + let store = Store::new(static_data, identity, store, verification); let session_cache = GroupSessionCache::new(store.clone()); let identity_manager = IdentityManager::new(store.clone()); @@ -1191,6 +1191,28 @@ mod tests { GossipMachine::new(store, identity_manager, session_cache, Default::default()) } + #[cfg(feature = "automatic-room-key-forwarding")] + async fn store_with_account_helper( + user_id: &UserId, + device_id: &DeviceId, + ) -> CryptoStoreWrapper { + // Properly create the store by first saving the own device and then the account + // data. + let account = Account::with_device_id(user_id, device_id); + let device = DeviceData::from_account(&account); + device.set_trust_state(LocalTrust::Verified); + + let changes = Changes { + devices: DeviceChanges { new: vec![device], ..Default::default() }, + ..Default::default() + }; + let mem_store = MemoryStore::new(); + mem_store.save_changes(changes).await.unwrap(); + mem_store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); + + CryptoStoreWrapper::new(user_id, device_id, mem_store) + } + async fn get_machine_test_helper() -> GossipMachine { let user_id = alice_id().to_owned(); let account = Account::with_device_id(&user_id, alice_device_id());