crypto: make Account not clonable and the Account in the cache optional

This commit is contained in:
Benjamin Bouvier
2023-10-10 14:02:33 +02:00
parent 3570519c0c
commit 8bfcd20801
12 changed files with 156 additions and 101 deletions

View File

@@ -54,12 +54,13 @@ impl Drop for DehydratedDevices {
#[uniffi::export]
impl DehydratedDevices {
pub fn create(&self) -> Arc<DehydratedDevice> {
DehydratedDevice {
inner: ManuallyDrop::new(self.inner.create()),
pub fn create(&self) -> Result<Arc<DehydratedDevice>, DehydrationError> {
let inner = self.runtime.block_on(self.inner.create())?;
Ok(Arc::new(DehydratedDevice {
inner: ManuallyDrop::new(inner),
runtime: self.runtime.to_owned(),
}
.into()
}))
}
pub fn rehydrate(

View File

@@ -91,7 +91,7 @@ pub struct DehydratedDevices {
impl DehydratedDevices {
/// Create a new [`DehydratedDevice`] which can be uploaded to the server.
pub fn create(&self) -> DehydratedDevice {
pub async fn create(&self) -> Result<DehydratedDevice, DehydrationError> {
let user_id = self.inner.user_id();
let user_identity = self.inner.store().private_identity();
@@ -104,9 +104,11 @@ impl DehydratedDevices {
store.clone(),
);
let store = Store::new(account, user_identity, store, verification_machine);
let store =
Store::new(account.static_data().clone(), user_identity, store, verification_machine);
store.save_pending_changes(crate::store::PendingChanges { account: Some(account) }).await?;
DehydratedDevice { store }
Ok(DehydratedDevice { store })
}
/// Rehydrate the dehydrated device.
@@ -456,10 +458,10 @@ mod tests {
}
#[async_test]
async fn dehydrated_device_creation() {
async fn test_dehydrated_device_creation() {
let olm_machine = get_olm_machine().await;
let dehydrated_device = olm_machine.dehydrated_devices().create();
let dehydrated_device = olm_machine.dehydrated_devices().create().await.unwrap();
let request = dehydrated_device
.keys_for_upload("Foo".to_owned(), PICKLE_KEY)
@@ -482,7 +484,7 @@ mod tests {
let room_id = room_id!("!test:example.org");
let alice = get_olm_machine().await;
let dehydrated_device = alice.dehydrated_devices().create();
let dehydrated_device = alice.dehydrated_devices().create().await.unwrap();
let mut request = dehydrated_device
.keys_for_upload("Foo".to_owned(), PICKLE_KEY)

View File

@@ -1069,7 +1069,7 @@ mod tests {
identities::{LocalTrust, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity},
session_manager::GroupSessionCache,
store::{CryptoStoreWrapper, MemoryStore, Store},
store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store},
types::events::room::encrypted::{EncryptedEvent, RoomEncryptedEventContent},
verification::VerificationMachine,
};
@@ -1113,7 +1113,7 @@ mod tests {
}
#[cfg(feature = "automatic-room-key-forwarding")]
fn test_gossip_machine(user_id: &UserId) -> GossipMachine {
async fn gossip_machine_test_helper(user_id: &UserId) -> GossipMachine {
let user_id = user_id.to_owned();
let device_id = DeviceId::new();
@@ -1122,13 +1122,15 @@ mod tests {
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
let store = Store::new(account, identity, store, verification);
let store = Store::new(account.static_data().clone(), identity, store, verification);
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
let session_cache = GroupSessionCache::new(store.clone());
GossipMachine::new(store, session_cache, Default::default())
}
async fn get_machine() -> GossipMachine {
async fn get_machine_test_helper() -> GossipMachine {
let user_id = alice_id().to_owned();
let account = Account::with_device_id(&user_id, alice_device_id());
let device = ReadOnlyDevice::from_account(&account).await;
@@ -1141,8 +1143,9 @@ mod tests {
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
let store = Store::new(account, identity, store, verification);
let store = Store::new(account.static_data().clone(), identity, store, verification);
store.save_devices(&[device, another_device]).await.unwrap();
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
let session_cache = GroupSessionCache::new(store.clone());
GossipMachine::new(store, session_cache, Default::default())
@@ -1154,16 +1157,16 @@ mod tests {
create_sessions: bool,
algorithm: EventEncryptionAlgorithm,
) -> (GossipMachine, OutboundGroupSession, GossipMachine) {
let alice_machine = get_machine().await;
let alice_machine = get_machine_test_helper().await;
let alice_device = ReadOnlyDevice::from_account(
alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
&alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
)
.await;
let bob_machine = test_gossip_machine(other_machine_owner);
let bob_machine = gossip_machine_test_helper(other_machine_owner).await;
let bob_device = ReadOnlyDevice::from_account(
bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
&*bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
)
.await;
@@ -1295,14 +1298,14 @@ mod tests {
#[async_test]
async fn create_machine() {
let machine = get_machine().await;
let machine = get_machine_test_helper().await;
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
}
#[async_test]
async fn re_request_keys() {
let machine = get_machine().await;
let machine = get_machine_test_helper().await;
let account = account();
let (outbound, session) = account.create_group_session_pair_with_defaults(room_id()).await;
@@ -1325,7 +1328,7 @@ mod tests {
#[async_test]
#[cfg(feature = "automatic-room-key-forwarding")]
async fn create_key_request() {
let machine = get_machine().await;
let machine = get_machine_test_helper().await;
let account = account();
let second_account = alice_2_account();
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
@@ -1357,7 +1360,7 @@ mod tests {
#[async_test]
#[cfg(feature = "automatic-room-key-forwarding")]
async fn receive_forwarded_key() {
let machine = get_machine().await;
let machine = get_machine_test_helper().await;
let account = account();
let second_account = alice_2_account();
@@ -1466,7 +1469,7 @@ mod tests {
#[async_test]
#[cfg(feature = "automatic-room-key-forwarding")]
async fn should_share_key_test() {
let machine = get_machine().await;
let machine = get_machine_test_helper().await;
let account = account();
let own_device =
@@ -1733,7 +1736,7 @@ mod tests {
#[async_test]
async fn test_secret_share_cycle() {
let alice_machine = get_machine().await;
let alice_machine = get_machine_test_helper().await;
let second_account = alice_2_account();
let alice_device = ReadOnlyDevice::from_account(&second_account).await;

View File

@@ -869,7 +869,7 @@ impl ReadOnlyDevice {
pub async fn from_machine_test_helper(
machine: &OlmMachine,
) -> Result<ReadOnlyDevice, crate::CryptoStoreError> {
Ok(ReadOnlyDevice::from_account(machine.store().cache().await?.account().await?).await)
Ok(ReadOnlyDevice::from_account(&*machine.store().cache().await?.account().await?).await)
}
/// Create a `ReadOnlyDevice` from an `Account`

View File

@@ -884,7 +884,7 @@ pub(crate) mod testing {
identities::IdentityManager,
machine::testing::response_from_file,
olm::{Account, PrivateCrossSigningIdentity},
store::{CryptoStoreWrapper, MemoryStore, Store},
store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store},
types::DeviceKeys,
verification::VerificationMachine,
UploadSigningKeysRequest,
@@ -902,15 +902,20 @@ pub(crate) mod testing {
device_id!("WSKKLTJZCL")
}
pub(crate) async fn manager(user_id: &UserId, device_id: &DeviceId) -> IdentityManager {
pub(crate) async fn manager_test_helper(
user_id: &UserId,
device_id: &DeviceId,
) -> IdentityManager {
let identity = PrivateCrossSigningIdentity::new(user_id.into()).await;
let identity = Arc::new(Mutex::new(identity));
let user_id = user_id.to_owned();
let account = Account::with_device_id(&user_id, device_id);
let static_account = account.static_data().clone();
let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new()));
let verification =
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
let store = Store::new(account, identity, store, verification);
VerificationMachine::new(static_account.clone(), identity.clone(), store.clone());
let store = Store::new(static_account, identity, store, verification);
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
IdentityManager::new(store)
}
@@ -1186,7 +1191,9 @@ pub(crate) mod tests {
use serde_json::json;
use stream_assert::{assert_closed, assert_pending, assert_ready};
use super::testing::{device_id, key_query, manager, other_key_query, other_user_id, user_id};
use super::testing::{
device_id, key_query, manager_test_helper, other_key_query, other_user_id, user_id,
};
use crate::{
identities::manager::testing::{other_key_query_cross_signed, own_key_query},
olm::PrivateCrossSigningIdentity,
@@ -1211,7 +1218,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_tracked_users() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let alice = user_id!("@alice:example.org");
assert!(
@@ -1231,13 +1238,13 @@ pub(crate) mod tests {
#[async_test]
async fn test_manager_creation() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
assert!(manager.store.tracked_users().await.unwrap().is_empty())
}
#[async_test]
async fn test_manager_key_query_response() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let other_user = other_user_id();
let devices = manager.store.get_user_devices(other_user).await.unwrap();
assert_eq!(devices.devices().count(), 0);
@@ -1264,7 +1271,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_manager_own_key_query_response() -> anyhow::Result<()> {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let our_user = user_id();
let devices = manager.store.get_user_devices(our_user).await?;
assert_eq!(devices.devices().count(), 0);
@@ -1299,7 +1306,7 @@ pub(crate) mod tests {
#[async_test]
async fn private_identity_invalidation_after_public_keys_change() {
let user_id = user_id!("@example1:localhost");
let manager = manager(user_id, "DEVICEID".into()).await;
let manager = manager_test_helper(user_id, "DEVICEID".into()).await;
let identity_request = {
let private_identity = manager.store.private_identity();
@@ -1387,7 +1394,7 @@ pub(crate) mod tests {
#[async_test]
async fn no_tracked_users_key_query_request() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
assert!(
manager.store.tracked_users().await.unwrap().is_empty(),
@@ -1407,8 +1414,8 @@ pub(crate) mod tests {
/// user is not removed from the list of outdated users when the
/// response is received
#[async_test]
async fn invalidation_race_handling() {
let manager = manager(user_id(), device_id()).await;
async fn test_invalidation_race_handling() {
let manager = manager_test_helper(user_id(), device_id()).await;
let alice = other_user_id();
manager.update_tracked_users([alice]).await.unwrap();
@@ -1436,7 +1443,7 @@ pub(crate) mod tests {
#[async_test]
async fn failure_handling() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let alice = user_id!("@alice:example.org");
assert!(
@@ -1476,7 +1483,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_out_of_band_key_query() {
// build the request
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let (reqid, req) = manager.build_key_query_for_users(vec![user_id()]);
assert!(req.device_keys.contains_key(user_id()));
@@ -1495,7 +1502,7 @@ pub(crate) mod tests {
#[async_test]
async fn devices_stream() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]);
let stream = manager.store.devices_stream();
@@ -1509,7 +1516,7 @@ pub(crate) mod tests {
#[async_test]
async fn identities_stream() {
let manager = manager(user_id(), device_id()).await;
let manager = manager_test_helper(user_id(), device_id()).await;
let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]);
let stream = manager.store.user_identities_stream();
@@ -1523,7 +1530,7 @@ pub(crate) mod tests {
#[async_test]
async fn identities_stream_raw() {
let mut manager = Some(manager(user_id(), device_id()).await);
let mut manager = Some(manager_test_helper(user_id(), device_id()).await);
let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![user_id()]);
let stream = manager.as_ref().unwrap().store.identities_stream_raw();
@@ -1568,7 +1575,7 @@ pub(crate) mod tests {
#[async_test]
async fn identities_stream_raw_signature_update() {
let mut manager = Some(manager(user_id(), device_id()).await);
let mut manager = Some(manager_test_helper(user_id(), device_id()).await);
let (request_id, _) =
manager.as_ref().unwrap().build_key_query_for_users(vec![other_user_id()]);

View File

@@ -147,7 +147,8 @@ impl OwnUserIdentity {
}
let cache = self.store.cache().await?;
cache.account().await?.sign_master_key(self.master_key.clone()).await
let account = cache.account().await?;
account.sign_master_key(self.master_key.clone()).await
}
/// Send a verification request to our other devices.

View File

@@ -64,6 +64,7 @@ use crate::{
olm::{
Account, CrossSigningStatus, EncryptionSettings, ExportedRoomKey, IdentityKeys,
InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, SessionType,
StaticAccountData,
},
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager},
@@ -168,25 +169,23 @@ impl OlmMachine {
) -> Result<OlmMachine, DehydrationError> {
let account =
Account::rehydrate(pickle_key, self.user_id(), device_id, device_data).await?;
let static_account = account.static_data().clone();
let store = Arc::new(CryptoStoreWrapper::new(self.user_id(), MemoryStore::new()));
store.save_pending_changes(PendingChanges { account: Some(account) }).await?;
Ok(Self::new_helper(device_id, store, account, self.store().private_identity()))
Ok(Self::new_helper(device_id, store, static_account, self.store().private_identity()))
}
fn new_helper(
device_id: &DeviceId,
store: Arc<CryptoStoreWrapper>,
account: Account,
account: StaticAccountData,
user_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
) -> Self {
let verification_machine = VerificationMachine::new(
account.static_data().clone(),
user_identity.clone(),
store.clone(),
);
let store =
Store::new(account.clone(), user_identity.clone(), store, verification_machine.clone());
let verification_machine =
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
let store = Store::new(account, user_identity.clone(), store, verification_machine.clone());
let group_session_manager = GroupSessionManager::new(store.clone());
@@ -248,7 +247,8 @@ impl OlmMachine {
store: impl IntoCryptoStore,
) -> StoreResult<Self> {
let store = store.into_crypto_store();
let account = match store.load_account().await? {
let static_account = match store.load_account().await? {
Some(account) => {
if user_id != account.user_id() || device_id != account.device_id() {
return Err(CryptoStoreError::MismatchedAccount {
@@ -262,10 +262,12 @@ impl OlmMachine {
.record("curve25519_key", display(account.identity_keys().curve25519));
debug!("Restored an Olm account");
account
account.static_data().clone()
}
None => {
let account = Account::with_device_id(user_id, device_id);
let static_account = account.static_data().clone();
Span::current()
.record("ed25519_key", display(account.identity_keys().ed25519))
@@ -283,13 +285,11 @@ impl OlmMachine {
..Default::default()
};
store.save_changes(changes).await?;
store
.save_pending_changes(PendingChanges { account: Some(account.clone()) })
.await?;
store.save_pending_changes(PendingChanges { account: Some(account) }).await?;
debug!("Created a new Olm account");
account
static_account
}
};
@@ -310,7 +310,7 @@ impl OlmMachine {
let identity = Arc::new(Mutex::new(identity));
let store = Arc::new(CryptoStoreWrapper::new(user_id, store));
Ok(OlmMachine::new_helper(device_id, store, account, identity))
Ok(OlmMachine::new_helper(device_id, store, static_account, identity))
}
/// Get the crypto store associated with this `OlmMachine` instance.
@@ -2028,7 +2028,8 @@ impl OlmMachine {
#[cfg(any(feature = "testing", test))]
pub async fn uploaded_key_count(&self) -> Result<u64, CryptoStoreError> {
let cache = self.inner.store.cache().await?;
Ok(cache.account().await?.uploaded_key_count())
let account = cache.account().await?;
Ok(account.uploaded_key_count())
}
}
@@ -2436,6 +2437,7 @@ pub(crate) mod tests {
let mut request =
machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload");
drop(account); // Release the account lock.
let one_time_key: SignedKey = request
.one_time_keys
@@ -2461,16 +2463,27 @@ pub(crate) mod tests {
);
ret.unwrap();
let mut response = keys_upload_response();
response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519,
(account.max_one_time_keys().await).try_into().unwrap(),
);
let response = {
let cache = machine.store().cache().await.unwrap();
let account = cache.account().await.unwrap();
let mut response = keys_upload_response();
response.one_time_key_counts.insert(
DeviceKeyAlgorithm::SignedCurve25519,
(account.max_one_time_keys().await).try_into().unwrap(),
);
response
};
machine.receive_keys_upload_response(&response).await.unwrap();
let ret = machine.keys_for_upload(&account).await;
assert!(ret.is_none());
{
let cache = machine.store().cache().await.unwrap();
let account = cache.account().await.unwrap();
let ret = machine.keys_for_upload(&account).await;
assert!(ret.is_none());
}
}
#[async_test]

View File

@@ -314,7 +314,6 @@ impl StaticAccountData {
///
/// An account is the central identity for encrypted communication between two
/// devices.
#[derive(Clone)]
pub struct Account {
pub(crate) static_data: StaticAccountData,
/// `vodozemac` account.
@@ -1356,6 +1355,18 @@ impl Account {
})
}
}
/// Test-only.
#[cfg(any(test, feature = "testing"))]
#[doc(hidden)]
pub fn clone_for_testing(&self) -> Self {
Self {
static_data: self.static_data.clone(),
inner: self.inner.clone(),
shared: self.shared.clone(),
uploaded_signed_key_count: self.uploaded_signed_key_count.clone(),
}
}
}
impl PartialEq for Account {

View File

@@ -484,7 +484,7 @@ mod tests {
identities::{IdentityManager, ReadOnlyDevice},
olm::{Account, PrivateCrossSigningIdentity},
session_manager::GroupSessionCache,
store::{CryptoStoreWrapper, MemoryStore, Store},
store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store},
verification::VerificationMachine,
};
@@ -531,7 +531,6 @@ mod tests {
let user_id = user_id();
let device_id = device_id();
let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
let account = Account::with_device_id(user_id, device_id);
let store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new()));
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
@@ -541,15 +540,12 @@ mod tests {
store.clone(),
);
let store = Store::new(account.clone(), identity, store, verification);
{
// Perform a dummy transaction to sync in-memory cache with the db.
let tr = store.transaction().await.unwrap();
tr.commit().await.unwrap();
}
let store = Store::new(account.static_data().clone(), identity, store, verification);
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
let session_cache = GroupSessionCache::new(store.clone());
let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
let key_request =
GossipMachine::new(store.clone(), session_cache, users_for_key_claim.clone());
@@ -656,7 +652,7 @@ mod tests {
let (_, mut session) = {
let cache = manager.store.cache().await.unwrap();
let manager_account = cache.account().await.unwrap();
let manager_account = &*cache.account().await.unwrap();
bob.create_session_for(manager_account).await
};

View File

@@ -66,7 +66,7 @@ macro_rules! cryptostore_integration_tests {
let store = get_store(name, None).await;
let account = get_account();
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
(account, store)
}
@@ -124,7 +124,7 @@ macro_rules! cryptostore_integration_tests {
let store = get_store("load_account", None).await;
let account = get_account();
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@@ -138,7 +138,7 @@ macro_rules! cryptostore_integration_tests {
get_store("load_account_with_passphrase", Some("secret_passphrase")).await;
let account = get_account();
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@@ -151,12 +151,12 @@ macro_rules! cryptostore_integration_tests {
let store = get_store("save_and_share_account", None).await;
let account = get_account();
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
account.mark_as_shared();
account.update_uploaded_key_count(50);
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
let loaded_account = store.load_account().await.expect("Can't load account");
let loaded_account = loaded_account.unwrap();
@@ -169,7 +169,7 @@ macro_rules! cryptostore_integration_tests {
async fn load_sessions() {
let store = get_store("load_sessions", None).await;
let (account, session) = get_account_and_session().await;
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
@@ -193,7 +193,7 @@ macro_rules! cryptostore_integration_tests {
let sender_key = session.sender_key.to_base64();
let session_id = session.session_id().to_owned();
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
store.save_changes(changes).await.unwrap();
@@ -491,7 +491,7 @@ macro_rules! cryptostore_integration_tests {
let account = Account::with_device_id(&user_id, device_id);
store.save_pending_changes(PendingChanges { account: Some(account.clone()), })
store.save_pending_changes(PendingChanges { account: Some(account), })
.await
.expect("Can't save account");

View File

@@ -52,6 +52,7 @@ fn encode_key_info(info: &SecretInfo) -> String {
/// An in-memory only store that will forget all the E2EE key once it's dropped.
#[derive(Debug)]
pub struct MemoryStore {
account: StdRwLock<Option<Account>>,
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
@@ -70,6 +71,7 @@ pub struct MemoryStore {
impl Default for MemoryStore {
fn default() -> Self {
MemoryStore {
account: Default::default(),
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
olm_hashes: Default::default(),
@@ -126,7 +128,7 @@ impl CryptoStore for MemoryStore {
type Error = Infallible;
async fn load_account(&self) -> Result<Option<Account>> {
Ok(None)
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.clone_for_testing()))
}
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
@@ -137,8 +139,11 @@ impl CryptoStore for MemoryStore {
Ok(self.next_batch_token.read().await.clone())
}
async fn save_pending_changes(&self, _changes: PendingChanges) -> Result<()> {
// TODO(bnjbvr) why didn't save_changes save the account?
async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
if let Some(account) = changes.account {
*self.account.write().unwrap() = Some(account);
}
Ok(())
}

View File

@@ -55,7 +55,7 @@ use ruma::{
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use tokio::sync::{Mutex, MutexGuard, RwLock};
use tracing::{info, warn};
use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
use zeroize::Zeroize;
@@ -107,9 +107,11 @@ pub struct Store {
#[derive(Debug)]
pub(crate) struct StoreCache {
store: Arc<CryptoStoreWrapper>,
tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
tracked_user_loading_lock: RwLock<bool>,
account: Account,
account: Mutex<Option<Account>>,
}
impl StoreCache {
@@ -119,8 +121,19 @@ impl StoreCache {
///
/// Note there should always be an account stored at least in the store, so this doesn't return
/// an `Option`.
pub async fn account(&self) -> Result<&Account> {
Ok(&self.account)
pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
let mut guard = self.account.lock().await;
if guard.is_some() {
Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
} else {
match self.store.load_account().await? {
Some(account) => {
*guard = Some(account);
Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
}
None => Err(CryptoStoreError::AccountUnset),
}
}
}
}
@@ -163,7 +176,9 @@ impl StoreTransaction {
/// Gets a `Account` for update.
pub async fn account(&mut self) -> Result<&mut Account> {
if self.changes.account.is_none() {
self.changes.account = Some(self.cache.account.clone());
// Make sure the cache loaded the account.
let _ = self.cache.account().await?;
self.changes.account = self.cache.account.lock().await.take();
}
Ok(self.changes.account.as_mut().unwrap())
}
@@ -604,23 +619,24 @@ impl From<&InboundGroupSession> for RoomKeyInfo {
impl Store {
/// Create a new Store.
pub(crate) fn new(
account: Account,
account: StaticAccountData,
identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
store: Arc<CryptoStoreWrapper>,
verification_machine: VerificationMachine,
) -> Self {
Self {
inner: Arc::new(StoreInner {
static_account: account.static_data().clone(),
static_account: account,
identity,
store,
store: store.clone(),
verification_machine,
users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()),
users_for_key_query_condvar: Condvar::new(),
cache: Arc::new(StoreCache {
store,
tracked_users: Default::default(),
tracked_user_loading_lock: Default::default(),
account,
account: Default::default(),
}),
}),
}