mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-04-24 09:08:23 -04:00
crypto: make Account not clonable and the Account in the cache optional
This commit is contained in:
@@ -54,12 +54,13 @@ impl Drop for DehydratedDevices {
|
||||
|
||||
#[uniffi::export]
|
||||
impl DehydratedDevices {
|
||||
pub fn create(&self) -> Arc<DehydratedDevice> {
|
||||
DehydratedDevice {
|
||||
inner: ManuallyDrop::new(self.inner.create()),
|
||||
pub fn create(&self) -> Result<Arc<DehydratedDevice>, DehydrationError> {
|
||||
let inner = self.runtime.block_on(self.inner.create())?;
|
||||
|
||||
Ok(Arc::new(DehydratedDevice {
|
||||
inner: ManuallyDrop::new(inner),
|
||||
runtime: self.runtime.to_owned(),
|
||||
}
|
||||
.into()
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn rehydrate(
|
||||
|
||||
@@ -91,7 +91,7 @@ pub struct DehydratedDevices {
|
||||
|
||||
impl DehydratedDevices {
|
||||
/// Create a new [`DehydratedDevice`] which can be uploaded to the server.
|
||||
pub fn create(&self) -> DehydratedDevice {
|
||||
pub async fn create(&self) -> Result<DehydratedDevice, DehydrationError> {
|
||||
let user_id = self.inner.user_id();
|
||||
let user_identity = self.inner.store().private_identity();
|
||||
|
||||
@@ -104,9 +104,11 @@ impl DehydratedDevices {
|
||||
store.clone(),
|
||||
);
|
||||
|
||||
let store = Store::new(account, user_identity, store, verification_machine);
|
||||
let store =
|
||||
Store::new(account.static_data().clone(), user_identity, store, verification_machine);
|
||||
store.save_pending_changes(crate::store::PendingChanges { account: Some(account) }).await?;
|
||||
|
||||
DehydratedDevice { store }
|
||||
Ok(DehydratedDevice { store })
|
||||
}
|
||||
|
||||
/// Rehydrate the dehydrated device.
|
||||
@@ -456,10 +458,10 @@ mod tests {
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn dehydrated_device_creation() {
|
||||
async fn test_dehydrated_device_creation() {
|
||||
let olm_machine = get_olm_machine().await;
|
||||
|
||||
let dehydrated_device = olm_machine.dehydrated_devices().create();
|
||||
let dehydrated_device = olm_machine.dehydrated_devices().create().await.unwrap();
|
||||
|
||||
let request = dehydrated_device
|
||||
.keys_for_upload("Foo".to_owned(), PICKLE_KEY)
|
||||
@@ -482,7 +484,7 @@ mod tests {
|
||||
let room_id = room_id!("!test:example.org");
|
||||
let alice = get_olm_machine().await;
|
||||
|
||||
let dehydrated_device = alice.dehydrated_devices().create();
|
||||
let dehydrated_device = alice.dehydrated_devices().create().await.unwrap();
|
||||
|
||||
let mut request = dehydrated_device
|
||||
.keys_for_upload("Foo".to_owned(), PICKLE_KEY)
|
||||
|
||||
@@ -1069,7 +1069,7 @@ mod tests {
|
||||
identities::{LocalTrust, ReadOnlyDevice},
|
||||
olm::{Account, PrivateCrossSigningIdentity},
|
||||
session_manager::GroupSessionCache,
|
||||
store::{CryptoStoreWrapper, MemoryStore, Store},
|
||||
store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store},
|
||||
types::events::room::encrypted::{EncryptedEvent, RoomEncryptedEventContent},
|
||||
verification::VerificationMachine,
|
||||
};
|
||||
@@ -1113,7 +1113,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[cfg(feature = "automatic-room-key-forwarding")]
|
||||
fn test_gossip_machine(user_id: &UserId) -> GossipMachine {
|
||||
async fn gossip_machine_test_helper(user_id: &UserId) -> GossipMachine {
|
||||
let user_id = user_id.to_owned();
|
||||
let device_id = DeviceId::new();
|
||||
|
||||
@@ -1122,13 +1122,15 @@ mod tests {
|
||||
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())));
|
||||
let verification =
|
||||
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
|
||||
let store = Store::new(account, identity, store, verification);
|
||||
let store = Store::new(account.static_data().clone(), identity, store, verification);
|
||||
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
|
||||
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
GossipMachine::new(store, session_cache, Default::default())
|
||||
}
|
||||
|
||||
async fn get_machine() -> GossipMachine {
|
||||
async fn get_machine_test_helper() -> GossipMachine {
|
||||
let user_id = alice_id().to_owned();
|
||||
let account = Account::with_device_id(&user_id, alice_device_id());
|
||||
let device = ReadOnlyDevice::from_account(&account).await;
|
||||
@@ -1141,8 +1143,9 @@ mod tests {
|
||||
let verification =
|
||||
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
|
||||
|
||||
let store = Store::new(account, identity, store, verification);
|
||||
let store = Store::new(account.static_data().clone(), identity, store, verification);
|
||||
store.save_devices(&[device, another_device]).await.unwrap();
|
||||
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
GossipMachine::new(store, session_cache, Default::default())
|
||||
@@ -1154,16 +1157,16 @@ mod tests {
|
||||
create_sessions: bool,
|
||||
algorithm: EventEncryptionAlgorithm,
|
||||
) -> (GossipMachine, OutboundGroupSession, GossipMachine) {
|
||||
let alice_machine = get_machine().await;
|
||||
let alice_machine = get_machine_test_helper().await;
|
||||
let alice_device = ReadOnlyDevice::from_account(
|
||||
alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
|
||||
&alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let bob_machine = test_gossip_machine(other_machine_owner);
|
||||
let bob_machine = gossip_machine_test_helper(other_machine_owner).await;
|
||||
|
||||
let bob_device = ReadOnlyDevice::from_account(
|
||||
bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
|
||||
&*bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1295,14 +1298,14 @@ mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn create_machine() {
|
||||
let machine = get_machine().await;
|
||||
let machine = get_machine_test_helper().await;
|
||||
|
||||
assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn re_request_keys() {
|
||||
let machine = get_machine().await;
|
||||
let machine = get_machine_test_helper().await;
|
||||
let account = account();
|
||||
|
||||
let (outbound, session) = account.create_group_session_pair_with_defaults(room_id()).await;
|
||||
@@ -1325,7 +1328,7 @@ mod tests {
|
||||
#[async_test]
|
||||
#[cfg(feature = "automatic-room-key-forwarding")]
|
||||
async fn create_key_request() {
|
||||
let machine = get_machine().await;
|
||||
let machine = get_machine_test_helper().await;
|
||||
let account = account();
|
||||
let second_account = alice_2_account();
|
||||
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
|
||||
@@ -1357,7 +1360,7 @@ mod tests {
|
||||
#[async_test]
|
||||
#[cfg(feature = "automatic-room-key-forwarding")]
|
||||
async fn receive_forwarded_key() {
|
||||
let machine = get_machine().await;
|
||||
let machine = get_machine_test_helper().await;
|
||||
let account = account();
|
||||
|
||||
let second_account = alice_2_account();
|
||||
@@ -1466,7 +1469,7 @@ mod tests {
|
||||
#[async_test]
|
||||
#[cfg(feature = "automatic-room-key-forwarding")]
|
||||
async fn should_share_key_test() {
|
||||
let machine = get_machine().await;
|
||||
let machine = get_machine_test_helper().await;
|
||||
let account = account();
|
||||
|
||||
let own_device =
|
||||
@@ -1733,7 +1736,7 @@ mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn test_secret_share_cycle() {
|
||||
let alice_machine = get_machine().await;
|
||||
let alice_machine = get_machine_test_helper().await;
|
||||
|
||||
let second_account = alice_2_account();
|
||||
let alice_device = ReadOnlyDevice::from_account(&second_account).await;
|
||||
|
||||
@@ -869,7 +869,7 @@ impl ReadOnlyDevice {
|
||||
pub async fn from_machine_test_helper(
|
||||
machine: &OlmMachine,
|
||||
) -> Result<ReadOnlyDevice, crate::CryptoStoreError> {
|
||||
Ok(ReadOnlyDevice::from_account(machine.store().cache().await?.account().await?).await)
|
||||
Ok(ReadOnlyDevice::from_account(&*machine.store().cache().await?.account().await?).await)
|
||||
}
|
||||
|
||||
/// Create a `ReadOnlyDevice` from an `Account`
|
||||
|
||||
@@ -884,7 +884,7 @@ pub(crate) mod testing {
|
||||
identities::IdentityManager,
|
||||
machine::testing::response_from_file,
|
||||
olm::{Account, PrivateCrossSigningIdentity},
|
||||
store::{CryptoStoreWrapper, MemoryStore, Store},
|
||||
store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store},
|
||||
types::DeviceKeys,
|
||||
verification::VerificationMachine,
|
||||
UploadSigningKeysRequest,
|
||||
@@ -902,15 +902,20 @@ pub(crate) mod testing {
|
||||
device_id!("WSKKLTJZCL")
|
||||
}
|
||||
|
||||
pub(crate) async fn manager(user_id: &UserId, device_id: &DeviceId) -> IdentityManager {
|
||||
pub(crate) async fn manager_test_helper(
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> IdentityManager {
|
||||
let identity = PrivateCrossSigningIdentity::new(user_id.into()).await;
|
||||
let identity = Arc::new(Mutex::new(identity));
|
||||
let user_id = user_id.to_owned();
|
||||
let account = Account::with_device_id(&user_id, device_id);
|
||||
let static_account = account.static_data().clone();
|
||||
let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new()));
|
||||
let verification =
|
||||
VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone());
|
||||
let store = Store::new(account, identity, store, verification);
|
||||
VerificationMachine::new(static_account.clone(), identity.clone(), store.clone());
|
||||
let store = Store::new(static_account, identity, store, verification);
|
||||
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
|
||||
IdentityManager::new(store)
|
||||
}
|
||||
|
||||
@@ -1186,7 +1191,9 @@ pub(crate) mod tests {
|
||||
use serde_json::json;
|
||||
use stream_assert::{assert_closed, assert_pending, assert_ready};
|
||||
|
||||
use super::testing::{device_id, key_query, manager, other_key_query, other_user_id, user_id};
|
||||
use super::testing::{
|
||||
device_id, key_query, manager_test_helper, other_key_query, other_user_id, user_id,
|
||||
};
|
||||
use crate::{
|
||||
identities::manager::testing::{other_key_query_cross_signed, own_key_query},
|
||||
olm::PrivateCrossSigningIdentity,
|
||||
@@ -1211,7 +1218,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn test_tracked_users() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let alice = user_id!("@alice:example.org");
|
||||
|
||||
assert!(
|
||||
@@ -1231,13 +1238,13 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn test_manager_creation() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
assert!(manager.store.tracked_users().await.unwrap().is_empty())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_manager_key_query_response() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let other_user = other_user_id();
|
||||
let devices = manager.store.get_user_devices(other_user).await.unwrap();
|
||||
assert_eq!(devices.devices().count(), 0);
|
||||
@@ -1264,7 +1271,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn test_manager_own_key_query_response() -> anyhow::Result<()> {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let our_user = user_id();
|
||||
let devices = manager.store.get_user_devices(our_user).await?;
|
||||
assert_eq!(devices.devices().count(), 0);
|
||||
@@ -1299,7 +1306,7 @@ pub(crate) mod tests {
|
||||
#[async_test]
|
||||
async fn private_identity_invalidation_after_public_keys_change() {
|
||||
let user_id = user_id!("@example1:localhost");
|
||||
let manager = manager(user_id, "DEVICEID".into()).await;
|
||||
let manager = manager_test_helper(user_id, "DEVICEID".into()).await;
|
||||
|
||||
let identity_request = {
|
||||
let private_identity = manager.store.private_identity();
|
||||
@@ -1387,7 +1394,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn no_tracked_users_key_query_request() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
|
||||
assert!(
|
||||
manager.store.tracked_users().await.unwrap().is_empty(),
|
||||
@@ -1407,8 +1414,8 @@ pub(crate) mod tests {
|
||||
/// user is not removed from the list of outdated users when the
|
||||
/// response is received
|
||||
#[async_test]
|
||||
async fn invalidation_race_handling() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
async fn test_invalidation_race_handling() {
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let alice = other_user_id();
|
||||
manager.update_tracked_users([alice]).await.unwrap();
|
||||
|
||||
@@ -1436,7 +1443,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn failure_handling() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let alice = user_id!("@alice:example.org");
|
||||
|
||||
assert!(
|
||||
@@ -1476,7 +1483,7 @@ pub(crate) mod tests {
|
||||
#[async_test]
|
||||
async fn test_out_of_band_key_query() {
|
||||
// build the request
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let (reqid, req) = manager.build_key_query_for_users(vec![user_id()]);
|
||||
assert!(req.device_keys.contains_key(user_id()));
|
||||
|
||||
@@ -1495,7 +1502,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn devices_stream() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]);
|
||||
|
||||
let stream = manager.store.devices_stream();
|
||||
@@ -1509,7 +1516,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn identities_stream() {
|
||||
let manager = manager(user_id(), device_id()).await;
|
||||
let manager = manager_test_helper(user_id(), device_id()).await;
|
||||
let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]);
|
||||
|
||||
let stream = manager.store.user_identities_stream();
|
||||
@@ -1523,7 +1530,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn identities_stream_raw() {
|
||||
let mut manager = Some(manager(user_id(), device_id()).await);
|
||||
let mut manager = Some(manager_test_helper(user_id(), device_id()).await);
|
||||
let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![user_id()]);
|
||||
|
||||
let stream = manager.as_ref().unwrap().store.identities_stream_raw();
|
||||
@@ -1568,7 +1575,7 @@ pub(crate) mod tests {
|
||||
|
||||
#[async_test]
|
||||
async fn identities_stream_raw_signature_update() {
|
||||
let mut manager = Some(manager(user_id(), device_id()).await);
|
||||
let mut manager = Some(manager_test_helper(user_id(), device_id()).await);
|
||||
let (request_id, _) =
|
||||
manager.as_ref().unwrap().build_key_query_for_users(vec![other_user_id()]);
|
||||
|
||||
|
||||
@@ -147,7 +147,8 @@ impl OwnUserIdentity {
|
||||
}
|
||||
|
||||
let cache = self.store.cache().await?;
|
||||
cache.account().await?.sign_master_key(self.master_key.clone()).await
|
||||
let account = cache.account().await?;
|
||||
account.sign_master_key(self.master_key.clone()).await
|
||||
}
|
||||
|
||||
/// Send a verification request to our other devices.
|
||||
|
||||
@@ -64,6 +64,7 @@ use crate::{
|
||||
olm::{
|
||||
Account, CrossSigningStatus, EncryptionSettings, ExportedRoomKey, IdentityKeys,
|
||||
InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, SessionType,
|
||||
StaticAccountData,
|
||||
},
|
||||
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
|
||||
session_manager::{GroupSessionManager, SessionManager},
|
||||
@@ -168,25 +169,23 @@ impl OlmMachine {
|
||||
) -> Result<OlmMachine, DehydrationError> {
|
||||
let account =
|
||||
Account::rehydrate(pickle_key, self.user_id(), device_id, device_data).await?;
|
||||
let static_account = account.static_data().clone();
|
||||
|
||||
let store = Arc::new(CryptoStoreWrapper::new(self.user_id(), MemoryStore::new()));
|
||||
store.save_pending_changes(PendingChanges { account: Some(account) }).await?;
|
||||
|
||||
Ok(Self::new_helper(device_id, store, account, self.store().private_identity()))
|
||||
Ok(Self::new_helper(device_id, store, static_account, self.store().private_identity()))
|
||||
}
|
||||
|
||||
fn new_helper(
|
||||
device_id: &DeviceId,
|
||||
store: Arc<CryptoStoreWrapper>,
|
||||
account: Account,
|
||||
account: StaticAccountData,
|
||||
user_identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
||||
) -> Self {
|
||||
let verification_machine = VerificationMachine::new(
|
||||
account.static_data().clone(),
|
||||
user_identity.clone(),
|
||||
store.clone(),
|
||||
);
|
||||
let store =
|
||||
Store::new(account.clone(), user_identity.clone(), store, verification_machine.clone());
|
||||
let verification_machine =
|
||||
VerificationMachine::new(account.clone(), user_identity.clone(), store.clone());
|
||||
let store = Store::new(account, user_identity.clone(), store, verification_machine.clone());
|
||||
|
||||
let group_session_manager = GroupSessionManager::new(store.clone());
|
||||
|
||||
@@ -248,7 +247,8 @@ impl OlmMachine {
|
||||
store: impl IntoCryptoStore,
|
||||
) -> StoreResult<Self> {
|
||||
let store = store.into_crypto_store();
|
||||
let account = match store.load_account().await? {
|
||||
|
||||
let static_account = match store.load_account().await? {
|
||||
Some(account) => {
|
||||
if user_id != account.user_id() || device_id != account.device_id() {
|
||||
return Err(CryptoStoreError::MismatchedAccount {
|
||||
@@ -262,10 +262,12 @@ impl OlmMachine {
|
||||
.record("curve25519_key", display(account.identity_keys().curve25519));
|
||||
debug!("Restored an Olm account");
|
||||
|
||||
account
|
||||
account.static_data().clone()
|
||||
}
|
||||
|
||||
None => {
|
||||
let account = Account::with_device_id(user_id, device_id);
|
||||
let static_account = account.static_data().clone();
|
||||
|
||||
Span::current()
|
||||
.record("ed25519_key", display(account.identity_keys().ed25519))
|
||||
@@ -283,13 +285,11 @@ impl OlmMachine {
|
||||
..Default::default()
|
||||
};
|
||||
store.save_changes(changes).await?;
|
||||
store
|
||||
.save_pending_changes(PendingChanges { account: Some(account.clone()) })
|
||||
.await?;
|
||||
store.save_pending_changes(PendingChanges { account: Some(account) }).await?;
|
||||
|
||||
debug!("Created a new Olm account");
|
||||
|
||||
account
|
||||
static_account
|
||||
}
|
||||
};
|
||||
|
||||
@@ -310,7 +310,7 @@ impl OlmMachine {
|
||||
|
||||
let identity = Arc::new(Mutex::new(identity));
|
||||
let store = Arc::new(CryptoStoreWrapper::new(user_id, store));
|
||||
Ok(OlmMachine::new_helper(device_id, store, account, identity))
|
||||
Ok(OlmMachine::new_helper(device_id, store, static_account, identity))
|
||||
}
|
||||
|
||||
/// Get the crypto store associated with this `OlmMachine` instance.
|
||||
@@ -2028,7 +2028,8 @@ impl OlmMachine {
|
||||
#[cfg(any(feature = "testing", test))]
|
||||
pub async fn uploaded_key_count(&self) -> Result<u64, CryptoStoreError> {
|
||||
let cache = self.inner.store.cache().await?;
|
||||
Ok(cache.account().await?.uploaded_key_count())
|
||||
let account = cache.account().await?;
|
||||
Ok(account.uploaded_key_count())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2436,6 +2437,7 @@ pub(crate) mod tests {
|
||||
|
||||
let mut request =
|
||||
machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload");
|
||||
drop(account); // Release the account lock.
|
||||
|
||||
let one_time_key: SignedKey = request
|
||||
.one_time_keys
|
||||
@@ -2461,16 +2463,27 @@ pub(crate) mod tests {
|
||||
);
|
||||
ret.unwrap();
|
||||
|
||||
let mut response = keys_upload_response();
|
||||
response.one_time_key_counts.insert(
|
||||
DeviceKeyAlgorithm::SignedCurve25519,
|
||||
(account.max_one_time_keys().await).try_into().unwrap(),
|
||||
);
|
||||
let response = {
|
||||
let cache = machine.store().cache().await.unwrap();
|
||||
let account = cache.account().await.unwrap();
|
||||
|
||||
let mut response = keys_upload_response();
|
||||
response.one_time_key_counts.insert(
|
||||
DeviceKeyAlgorithm::SignedCurve25519,
|
||||
(account.max_one_time_keys().await).try_into().unwrap(),
|
||||
);
|
||||
|
||||
response
|
||||
};
|
||||
|
||||
machine.receive_keys_upload_response(&response).await.unwrap();
|
||||
|
||||
let ret = machine.keys_for_upload(&account).await;
|
||||
assert!(ret.is_none());
|
||||
{
|
||||
let cache = machine.store().cache().await.unwrap();
|
||||
let account = cache.account().await.unwrap();
|
||||
let ret = machine.keys_for_upload(&account).await;
|
||||
assert!(ret.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
||||
@@ -314,7 +314,6 @@ impl StaticAccountData {
|
||||
///
|
||||
/// An account is the central identity for encrypted communication between two
|
||||
/// devices.
|
||||
#[derive(Clone)]
|
||||
pub struct Account {
|
||||
pub(crate) static_data: StaticAccountData,
|
||||
/// `vodozemac` account.
|
||||
@@ -1356,6 +1355,18 @@ impl Account {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Test-only.
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
#[doc(hidden)]
|
||||
pub fn clone_for_testing(&self) -> Self {
|
||||
Self {
|
||||
static_data: self.static_data.clone(),
|
||||
inner: self.inner.clone(),
|
||||
shared: self.shared.clone(),
|
||||
uploaded_signed_key_count: self.uploaded_signed_key_count.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Account {
|
||||
|
||||
@@ -484,7 +484,7 @@ mod tests {
|
||||
identities::{IdentityManager, ReadOnlyDevice},
|
||||
olm::{Account, PrivateCrossSigningIdentity},
|
||||
session_manager::GroupSessionCache,
|
||||
store::{CryptoStoreWrapper, MemoryStore, Store},
|
||||
store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store},
|
||||
verification::VerificationMachine,
|
||||
};
|
||||
|
||||
@@ -531,7 +531,6 @@ mod tests {
|
||||
let user_id = user_id();
|
||||
let device_id = device_id();
|
||||
|
||||
let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
|
||||
let account = Account::with_device_id(user_id, device_id);
|
||||
let store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new()));
|
||||
let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
|
||||
@@ -541,15 +540,12 @@ mod tests {
|
||||
store.clone(),
|
||||
);
|
||||
|
||||
let store = Store::new(account.clone(), identity, store, verification);
|
||||
{
|
||||
// Perform a dummy transaction to sync in-memory cache with the db.
|
||||
let tr = store.transaction().await.unwrap();
|
||||
tr.commit().await.unwrap();
|
||||
}
|
||||
let store = Store::new(account.static_data().clone(), identity, store, verification);
|
||||
store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
|
||||
|
||||
let session_cache = GroupSessionCache::new(store.clone());
|
||||
|
||||
let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
|
||||
let key_request =
|
||||
GossipMachine::new(store.clone(), session_cache, users_for_key_claim.clone());
|
||||
|
||||
@@ -656,7 +652,7 @@ mod tests {
|
||||
|
||||
let (_, mut session) = {
|
||||
let cache = manager.store.cache().await.unwrap();
|
||||
let manager_account = cache.account().await.unwrap();
|
||||
let manager_account = &*cache.account().await.unwrap();
|
||||
bob.create_session_for(manager_account).await
|
||||
};
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
let store = get_store(name, None).await;
|
||||
let account = get_account();
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
(account, store)
|
||||
}
|
||||
@@ -124,7 +124,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
let store = get_store("load_account", None).await;
|
||||
let account = get_account();
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||
let loaded_account = loaded_account.unwrap();
|
||||
@@ -138,7 +138,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
get_store("load_account_with_passphrase", Some("secret_passphrase")).await;
|
||||
let account = get_account();
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||
let loaded_account = loaded_account.unwrap();
|
||||
@@ -151,12 +151,12 @@ macro_rules! cryptostore_integration_tests {
|
||||
let store = get_store("save_and_share_account", None).await;
|
||||
let account = get_account();
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
account.mark_as_shared();
|
||||
account.update_uploaded_key_count(50);
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
let loaded_account = store.load_account().await.expect("Can't load account");
|
||||
let loaded_account = loaded_account.unwrap();
|
||||
@@ -169,7 +169,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
async fn load_sessions() {
|
||||
let store = get_store("load_sessions", None).await;
|
||||
let (account, session) = get_account_and_session().await;
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
|
||||
|
||||
@@ -193,7 +193,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
let sender_key = session.sender_key.to_base64();
|
||||
let session_id = session.session_id().to_owned();
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account");
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account");
|
||||
|
||||
let changes = Changes { sessions: vec![session.clone()], ..Default::default() };
|
||||
store.save_changes(changes).await.unwrap();
|
||||
@@ -491,7 +491,7 @@ macro_rules! cryptostore_integration_tests {
|
||||
|
||||
let account = Account::with_device_id(&user_id, device_id);
|
||||
|
||||
store.save_pending_changes(PendingChanges { account: Some(account.clone()), })
|
||||
store.save_pending_changes(PendingChanges { account: Some(account), })
|
||||
.await
|
||||
.expect("Can't save account");
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ fn encode_key_info(info: &SecretInfo) -> String {
|
||||
/// An in-memory only store that will forget all the E2EE key once it's dropped.
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryStore {
|
||||
account: StdRwLock<Option<Account>>,
|
||||
sessions: SessionStore,
|
||||
inbound_group_sessions: GroupSessionStore,
|
||||
olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
|
||||
@@ -70,6 +71,7 @@ pub struct MemoryStore {
|
||||
impl Default for MemoryStore {
|
||||
fn default() -> Self {
|
||||
MemoryStore {
|
||||
account: Default::default(),
|
||||
sessions: SessionStore::new(),
|
||||
inbound_group_sessions: GroupSessionStore::new(),
|
||||
olm_hashes: Default::default(),
|
||||
@@ -126,7 +128,7 @@ impl CryptoStore for MemoryStore {
|
||||
type Error = Infallible;
|
||||
|
||||
async fn load_account(&self) -> Result<Option<Account>> {
|
||||
Ok(None)
|
||||
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.clone_for_testing()))
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
@@ -137,8 +139,11 @@ impl CryptoStore for MemoryStore {
|
||||
Ok(self.next_batch_token.read().await.clone())
|
||||
}
|
||||
|
||||
async fn save_pending_changes(&self, _changes: PendingChanges) -> Result<()> {
|
||||
// TODO(bnjbvr) why didn't save_changes save the account?
|
||||
async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
|
||||
if let Some(account) = changes.account {
|
||||
*self.account.write().unwrap() = Some(account);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ use ruma::{
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio::sync::{Mutex, MutexGuard, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
|
||||
use zeroize::Zeroize;
|
||||
@@ -107,9 +107,11 @@ pub struct Store {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StoreCache {
|
||||
store: Arc<CryptoStoreWrapper>,
|
||||
|
||||
tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
|
||||
tracked_user_loading_lock: RwLock<bool>,
|
||||
account: Account,
|
||||
account: Mutex<Option<Account>>,
|
||||
}
|
||||
|
||||
impl StoreCache {
|
||||
@@ -119,8 +121,19 @@ impl StoreCache {
|
||||
///
|
||||
/// Note there should always be an account stored at least in the store, so this doesn't return
|
||||
/// an `Option`.
|
||||
pub async fn account(&self) -> Result<&Account> {
|
||||
Ok(&self.account)
|
||||
pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
|
||||
let mut guard = self.account.lock().await;
|
||||
if guard.is_some() {
|
||||
Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
|
||||
} else {
|
||||
match self.store.load_account().await? {
|
||||
Some(account) => {
|
||||
*guard = Some(account);
|
||||
Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
|
||||
}
|
||||
None => Err(CryptoStoreError::AccountUnset),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,7 +176,9 @@ impl StoreTransaction {
|
||||
/// Gets a `Account` for update.
|
||||
pub async fn account(&mut self) -> Result<&mut Account> {
|
||||
if self.changes.account.is_none() {
|
||||
self.changes.account = Some(self.cache.account.clone());
|
||||
// Make sure the cache loaded the account.
|
||||
let _ = self.cache.account().await?;
|
||||
self.changes.account = self.cache.account.lock().await.take();
|
||||
}
|
||||
Ok(self.changes.account.as_mut().unwrap())
|
||||
}
|
||||
@@ -604,23 +619,24 @@ impl From<&InboundGroupSession> for RoomKeyInfo {
|
||||
impl Store {
|
||||
/// Create a new Store.
|
||||
pub(crate) fn new(
|
||||
account: Account,
|
||||
account: StaticAccountData,
|
||||
identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
|
||||
store: Arc<CryptoStoreWrapper>,
|
||||
verification_machine: VerificationMachine,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(StoreInner {
|
||||
static_account: account.static_data().clone(),
|
||||
static_account: account,
|
||||
identity,
|
||||
store,
|
||||
store: store.clone(),
|
||||
verification_machine,
|
||||
users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()),
|
||||
users_for_key_query_condvar: Condvar::new(),
|
||||
cache: Arc::new(StoreCache {
|
||||
store,
|
||||
tracked_users: Default::default(),
|
||||
tracked_user_loading_lock: Default::default(),
|
||||
account,
|
||||
account: Default::default(),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user