mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-12 18:21:21 -04:00
refactor(crypto): Use the simplified locks across the crypto crate
This commit is contained in:
@@ -12,8 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk_common::locks::Mutex;
|
||||
use ruma::{
|
||||
api::client::backup::{EncryptedSessionDataInit, KeyBackupData, KeyBackupDataInit},
|
||||
serde::Base64,
|
||||
@@ -87,7 +88,7 @@ impl MegolmV1BackupKey {
|
||||
|
||||
/// Get the backup version that this key is used with, if any.
|
||||
pub fn backup_version(&self) -> Option<String> {
|
||||
self.inner.version.lock().unwrap().clone()
|
||||
self.inner.version.lock().clone()
|
||||
}
|
||||
|
||||
/// Set the backup version that this `MegolmV1BackupKey` will be used with.
|
||||
@@ -95,7 +96,7 @@ impl MegolmV1BackupKey {
|
||||
/// The key won't be able to encrypt room keys unless a version has been
|
||||
/// set.
|
||||
pub fn set_version(&self, version: String) {
|
||||
*self.inner.version.lock().unwrap() = Some(version);
|
||||
*self.inner.version.lock() = Some(version);
|
||||
}
|
||||
|
||||
/// Export the given inbound group session, and encrypt the data, ready for
|
||||
|
||||
@@ -25,10 +25,11 @@ use std::{
|
||||
mem,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, RwLock as StdRwLock,
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use ruma::{
|
||||
api::client::keys::claim_keys::v3::Request as KeysClaimRequest,
|
||||
events::secret::request::{
|
||||
@@ -168,14 +169,13 @@ impl GossipMachine {
|
||||
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
|
||||
let mut key_requests = self.load_outgoing_requests().await?;
|
||||
let key_forwards: Vec<OutgoingRequest> =
|
||||
self.inner.outgoing_requests.read().unwrap().values().cloned().collect();
|
||||
self.inner.outgoing_requests.read().values().cloned().collect();
|
||||
key_requests.extend(key_forwards);
|
||||
|
||||
let users_for_key_claim: BTreeMap<_, _> = self
|
||||
.inner
|
||||
.users_for_key_claim
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|(key, value)| {
|
||||
let device_map = value
|
||||
@@ -213,7 +213,7 @@ impl GossipMachine {
|
||||
trace!("Received a secret request event from ourselves, ignoring")
|
||||
} else {
|
||||
let request_info = event.to_request_info();
|
||||
self.inner.incoming_key_requests.write().unwrap().insert(request_info, event);
|
||||
self.inner.incoming_key_requests.write().insert(request_info, event);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,8 +229,7 @@ impl GossipMachine {
|
||||
) -> OlmResult<Vec<Session>> {
|
||||
let mut changed_sessions = Vec::new();
|
||||
|
||||
let incoming_key_requests =
|
||||
mem::take(&mut *self.inner.incoming_key_requests.write().unwrap());
|
||||
let incoming_key_requests = mem::take(&mut *self.inner.incoming_key_requests.write());
|
||||
|
||||
for event in incoming_key_requests.values() {
|
||||
if let Some(s) = match event {
|
||||
@@ -254,7 +253,6 @@ impl GossipMachine {
|
||||
self.inner
|
||||
.users_for_key_claim
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(device.user_id().to_owned())
|
||||
.or_default()
|
||||
.insert(device.device_id().into());
|
||||
@@ -275,7 +273,7 @@ impl GossipMachine {
|
||||
/// * `device_id` - The device ID of the device that got the Olm session.
|
||||
pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
if let Entry::Occupied(mut e) =
|
||||
self.inner.users_for_key_claim.write().unwrap().entry(user_id.to_owned())
|
||||
self.inner.users_for_key_claim.write().entry(user_id.to_owned())
|
||||
{
|
||||
e.get_mut().remove(device_id);
|
||||
|
||||
@@ -284,7 +282,7 @@ impl GossipMachine {
|
||||
}
|
||||
}
|
||||
|
||||
let mut incoming_key_requests = self.inner.incoming_key_requests.write().unwrap();
|
||||
let mut incoming_key_requests = self.inner.incoming_key_requests.write();
|
||||
for (key, event) in self.inner.wait_queue.remove(user_id, device_id) {
|
||||
incoming_key_requests.entry(key).or_insert(event);
|
||||
}
|
||||
@@ -555,7 +553,7 @@ impl GossipMachine {
|
||||
request_id: request.txn_id.clone(),
|
||||
request: Arc::new(request.into()),
|
||||
};
|
||||
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
|
||||
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
|
||||
|
||||
Ok(used_session)
|
||||
}
|
||||
@@ -581,7 +579,7 @@ impl GossipMachine {
|
||||
request_id: request.txn_id.clone(),
|
||||
request: Arc::new(request.into()),
|
||||
};
|
||||
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
|
||||
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
|
||||
|
||||
Ok(used_session)
|
||||
}
|
||||
@@ -824,7 +822,7 @@ impl GossipMachine {
|
||||
self.save_outgoing_key_info(info).await?;
|
||||
}
|
||||
|
||||
self.inner.outgoing_requests.write().unwrap().remove(id);
|
||||
self.inner.outgoing_requests.write().remove(id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -840,13 +838,13 @@ impl GossipMachine {
|
||||
"Successfully received a secret, removing the request"
|
||||
);
|
||||
|
||||
self.inner.outgoing_requests.write().unwrap().remove(&key_info.request_id);
|
||||
self.inner.outgoing_requests.write().remove(&key_info.request_id);
|
||||
// TODO return the key info instead of deleting it so the sync handler
|
||||
// can delete it in one transaction.
|
||||
self.delete_key_info(key_info).await?;
|
||||
|
||||
let request = key_info.to_cancellation(self.device_id());
|
||||
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
|
||||
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1511,7 +1509,6 @@ mod tests {
|
||||
.inner
|
||||
.outgoing_requests
|
||||
.read()
|
||||
.unwrap()
|
||||
.first_key_value()
|
||||
.map(|(_, r)| r.request_id.clone())
|
||||
.unwrap();
|
||||
@@ -1692,7 +1689,7 @@ mod tests {
|
||||
alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap();
|
||||
|
||||
// Bob doesn't have any outgoing requests.
|
||||
assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(bob_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
// Receive the room key request from alice.
|
||||
bob_machine.receive_incoming_key_request(&event);
|
||||
@@ -1702,7 +1699,7 @@ mod tests {
|
||||
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
|
||||
}
|
||||
// Now bob does have an outgoing request.
|
||||
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(!bob_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
// Get the request and convert it to a encrypted to-device event.
|
||||
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
|
||||
@@ -1774,7 +1771,7 @@ mod tests {
|
||||
alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap();
|
||||
|
||||
// Bob doesn't have any outgoing requests.
|
||||
assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(bob_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
// Receive the room key request from alice.
|
||||
bob_machine.receive_incoming_key_request(&event);
|
||||
@@ -1783,7 +1780,7 @@ mod tests {
|
||||
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
|
||||
}
|
||||
// Now bob does have an outgoing request.
|
||||
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(!bob_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
// Get the request and convert it to a encrypted to-device event.
|
||||
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
|
||||
@@ -1875,13 +1872,13 @@ mod tests {
|
||||
};
|
||||
|
||||
// No secret found
|
||||
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
|
||||
alice_machine.receive_incoming_secret_request(&event);
|
||||
{
|
||||
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
|
||||
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
|
||||
}
|
||||
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
// No device found
|
||||
alice_machine.inner.store.reset_cross_signing_identity().await;
|
||||
@@ -1890,7 +1887,7 @@ mod tests {
|
||||
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
|
||||
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
|
||||
}
|
||||
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
alice_machine.inner.store.save_device_data(&[bob_device]).await.unwrap();
|
||||
|
||||
@@ -1901,7 +1898,7 @@ mod tests {
|
||||
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
|
||||
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
|
||||
}
|
||||
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
let event = RumaToDeviceEvent {
|
||||
sender: alice_id().to_owned(),
|
||||
@@ -1918,7 +1915,7 @@ mod tests {
|
||||
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
|
||||
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
|
||||
}
|
||||
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
|
||||
|
||||
// We need a trusted device, otherwise we won't serve secrets
|
||||
alice_device.set_trust_state(LocalTrust::Verified);
|
||||
@@ -1929,7 +1926,7 @@ mod tests {
|
||||
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
|
||||
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
|
||||
}
|
||||
assert!(!alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
|
||||
assert!(!alice_machine.inner.outgoing_requests.read().is_empty());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
@@ -2053,7 +2050,7 @@ mod tests {
|
||||
|
||||
// Bob doesn't have any outgoing requests.
|
||||
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
|
||||
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
|
||||
assert!(bob_machine.inner.users_for_key_claim.read().is_empty());
|
||||
assert!(bob_machine.inner.wait_queue.is_empty());
|
||||
|
||||
// Receive the room key request from alice.
|
||||
@@ -2068,7 +2065,7 @@ mod tests {
|
||||
bob_machine.outgoing_to_device_requests().await.unwrap()[0].request(),
|
||||
AnyOutgoingRequest::KeysClaim(_)
|
||||
);
|
||||
assert!(!bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
|
||||
assert!(!bob_machine.inner.users_for_key_claim.read().is_empty());
|
||||
assert!(!bob_machine.inner.wait_queue.is_empty());
|
||||
|
||||
let (alice_session, bob_session) = alice_machine
|
||||
@@ -2096,7 +2093,7 @@ mod tests {
|
||||
bob_machine.inner.store.save_sessions(&[bob_session]).await.unwrap();
|
||||
|
||||
bob_machine.retry_keyshare(alice_id(), alice_device_id());
|
||||
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
|
||||
assert!(bob_machine.inner.users_for_key_claim.read().is_empty());
|
||||
{
|
||||
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
|
||||
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
|
||||
|
||||
@@ -16,10 +16,11 @@ mod machine;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub(crate) use machine::GossipMachine;
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use ruma::{
|
||||
events::{
|
||||
room_key_request::{Action, ToDeviceRoomKeyRequestEventContent},
|
||||
@@ -323,7 +324,7 @@ impl WaitQueue {
|
||||
|
||||
#[cfg(all(test, feature = "automatic-room-key-forwarding"))]
|
||||
fn is_empty(&self) -> bool {
|
||||
let read_guard = self.inner.read().unwrap();
|
||||
let read_guard = self.inner.read();
|
||||
read_guard.requests_ids_waiting.is_empty()
|
||||
&& read_guard.requests_waiting_for_session.is_empty()
|
||||
}
|
||||
@@ -337,13 +338,14 @@ impl WaitQueue {
|
||||
);
|
||||
let ids_waiting_key = (device.user_id().to_owned(), device.device_id().into());
|
||||
|
||||
let mut write_guard = self.inner.write().unwrap();
|
||||
let mut write_guard = self.inner.write();
|
||||
write_guard.requests_waiting_for_session.insert(requests_waiting_key, event);
|
||||
write_guard.requests_ids_waiting.entry(ids_waiting_key).or_default().insert(request_id);
|
||||
}
|
||||
|
||||
fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Vec<(RequestInfo, RequestEvent)> {
|
||||
let mut write_guard = self.inner.write().unwrap();
|
||||
let mut write_guard = self.inner.write();
|
||||
|
||||
write_guard
|
||||
.requests_ids_waiting
|
||||
.remove(&(user_id.to_owned(), device_id.into()))
|
||||
|
||||
@@ -17,11 +17,11 @@ use std::{
|
||||
ops::Deref,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, RwLock,
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use matrix_sdk_common::deserialized_responses::WithheldCode;
|
||||
use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock};
|
||||
use ruma::{
|
||||
api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest,
|
||||
events::{key::verification::VerificationMethod, AnyToDeviceEventContent},
|
||||
@@ -470,7 +470,7 @@ impl Device {
|
||||
) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
|
||||
let (used_session, raw_encrypted) = self.encrypt(event_type, content).await?;
|
||||
|
||||
// perist the used session
|
||||
// Persist the used session
|
||||
self.verification_machine
|
||||
.store
|
||||
.save_changes(Changes { sessions: vec![used_session], ..Default::default() })
|
||||
@@ -626,7 +626,7 @@ impl DeviceData {
|
||||
|
||||
/// Get the trust state of the device.
|
||||
pub fn local_trust_state(&self) -> LocalTrust {
|
||||
*self.trust_state.read().unwrap()
|
||||
*self.trust_state.read()
|
||||
}
|
||||
|
||||
/// Is the device locally marked as trusted.
|
||||
@@ -646,7 +646,7 @@ impl DeviceData {
|
||||
/// Note: This should only done in the crypto store where the trust state
|
||||
/// can be stored.
|
||||
pub(crate) fn set_trust_state(&self, state: LocalTrust) {
|
||||
*self.trust_state.write().unwrap() = state;
|
||||
*self.trust_state.write() = state;
|
||||
}
|
||||
|
||||
pub(crate) fn mark_withheld_code_as_sent(&self) {
|
||||
|
||||
@@ -17,11 +17,12 @@ use std::{
|
||||
ops::{Deref, DerefMut},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, RwLock,
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use as_variant::as_variant;
|
||||
use matrix_sdk_common::locks::RwLock;
|
||||
use ruma::{
|
||||
api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest,
|
||||
events::{
|
||||
@@ -684,7 +685,7 @@ impl From<OtherUserIdentityData> for OtherUserIdentityDataSerializer {
|
||||
user_id: value.user_id.clone(),
|
||||
master_key: value.master_key().to_owned(),
|
||||
self_signing_key: value.self_signing_key().to_owned(),
|
||||
pinned_master_key: value.pinned_master_key.read().unwrap().clone(),
|
||||
pinned_master_key: value.pinned_master_key.read().clone(),
|
||||
previously_verified: value.previously_verified.load(Ordering::SeqCst),
|
||||
};
|
||||
OtherUserIdentityDataSerializer {
|
||||
@@ -787,7 +788,7 @@ impl OtherUserIdentityData {
|
||||
/// which is not verified and is in pin violation. See
|
||||
/// [`OtherUserIdentity::identity_needs_user_approval`].
|
||||
pub(crate) fn pin(&self) {
|
||||
let mut m = self.pinned_master_key.write().unwrap();
|
||||
let mut m = self.pinned_master_key.write();
|
||||
*m = self.master_key.as_ref().clone()
|
||||
}
|
||||
|
||||
@@ -828,7 +829,7 @@ impl OtherUserIdentityData {
|
||||
/// accept and pin the new identity, perform a verification, or
|
||||
/// stop communications.
|
||||
pub(crate) fn has_pin_violation(&self) -> bool {
|
||||
let pinned_master_key = self.pinned_master_key.read().unwrap();
|
||||
let pinned_master_key = self.pinned_master_key.read();
|
||||
pinned_master_key.get_first_key() != self.master_key().get_first_key()
|
||||
}
|
||||
|
||||
@@ -858,7 +859,7 @@ impl OtherUserIdentityData {
|
||||
// the previous pinned master key.
|
||||
// This identity will have a pin violation until the new master key is pinned
|
||||
// (see `has_pin_violation()`).
|
||||
let pinned_master_key = self.pinned_master_key.read().unwrap().clone();
|
||||
let pinned_master_key = self.pinned_master_key.read().clone();
|
||||
|
||||
// Check if the new master_key is signed by our own **verified**
|
||||
// user_signing_key. If the identity was verified we remember it.
|
||||
@@ -947,7 +948,7 @@ impl PartialEq for OwnUserIdentityData {
|
||||
&& self.master_key == other.master_key
|
||||
&& self.self_signing_key == other.self_signing_key
|
||||
&& self.user_signing_key == other.user_signing_key
|
||||
&& *self.verified.read().unwrap() == *other.verified.read().unwrap()
|
||||
&& *self.verified.read() == *other.verified.read()
|
||||
&& self.master_key.signatures() == other.master_key.signatures()
|
||||
}
|
||||
}
|
||||
@@ -1067,12 +1068,12 @@ impl OwnUserIdentityData {
|
||||
|
||||
/// Mark our identity as verified.
|
||||
pub fn mark_as_verified(&self) {
|
||||
*self.verified.write().unwrap() = OwnUserIdentityVerifiedState::Verified;
|
||||
*self.verified.write() = OwnUserIdentityVerifiedState::Verified;
|
||||
}
|
||||
|
||||
/// Mark our identity as unverified.
|
||||
pub(crate) fn mark_as_unverified(&self) {
|
||||
let mut guard = self.verified.write().unwrap();
|
||||
let mut guard = self.verified.write();
|
||||
if *guard == OwnUserIdentityVerifiedState::Verified {
|
||||
*guard = OwnUserIdentityVerifiedState::VerificationViolation;
|
||||
}
|
||||
@@ -1080,7 +1081,7 @@ impl OwnUserIdentityData {
|
||||
|
||||
/// Check if our identity is verified.
|
||||
pub fn is_verified(&self) -> bool {
|
||||
*self.verified.read().unwrap() == OwnUserIdentityVerifiedState::Verified
|
||||
*self.verified.read() == OwnUserIdentityVerifiedState::Verified
|
||||
}
|
||||
|
||||
/// True if we verified our own identity at some point in the past.
|
||||
@@ -1089,7 +1090,7 @@ impl OwnUserIdentityData {
|
||||
/// [`OwnUserIdentityData::withdraw_verification()`].
|
||||
pub fn was_previously_verified(&self) -> bool {
|
||||
matches!(
|
||||
*self.verified.read().unwrap(),
|
||||
*self.verified.read(),
|
||||
OwnUserIdentityVerifiedState::Verified
|
||||
| OwnUserIdentityVerifiedState::VerificationViolation
|
||||
)
|
||||
@@ -1101,7 +1102,7 @@ impl OwnUserIdentityData {
|
||||
/// reported to the user. In order to remove this notice users have to
|
||||
/// verify again or to withdraw the verification requirement.
|
||||
pub fn withdraw_verification(&self) {
|
||||
let mut guard = self.verified.write().unwrap();
|
||||
let mut guard = self.verified.write();
|
||||
if *guard == OwnUserIdentityVerifiedState::VerificationViolation {
|
||||
*guard = OwnUserIdentityVerifiedState::NeverVerified;
|
||||
}
|
||||
@@ -1117,7 +1118,7 @@ impl OwnUserIdentityData {
|
||||
/// - Or by withdrawing the verification requirement
|
||||
/// [`OwnUserIdentity::withdraw_verification`].
|
||||
pub fn has_verification_violation(&self) -> bool {
|
||||
*self.verified.read().unwrap() == OwnUserIdentityVerifiedState::VerificationViolation
|
||||
*self.verified.read() == OwnUserIdentityVerifiedState::VerificationViolation
|
||||
}
|
||||
|
||||
/// Update the identity with a new master key and self signing key.
|
||||
@@ -1523,7 +1524,7 @@ pub(crate) mod tests {
|
||||
});
|
||||
let migrated: OtherUserIdentityData = serde_json::from_value(serialized_value).unwrap();
|
||||
|
||||
let pinned_master_key = migrated.pinned_master_key.read().unwrap();
|
||||
let pinned_master_key = migrated.pinned_master_key.read();
|
||||
assert_eq!(*pinned_master_key, migrated.master_key().clone());
|
||||
|
||||
// Serialize back
|
||||
@@ -1547,12 +1548,12 @@ pub(crate) mod tests {
|
||||
// Set `"verified": false`
|
||||
*json.get_mut("verified").unwrap() = false.into();
|
||||
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
|
||||
assert_eq!(*id.verified.read().unwrap(), OwnUserIdentityVerifiedState::NeverVerified);
|
||||
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::NeverVerified);
|
||||
|
||||
// Tweak the json to have `"verified": true`, and repeat
|
||||
*json.get_mut("verified").unwrap() = true.into();
|
||||
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
|
||||
assert_eq!(*id.verified.read().unwrap(), OwnUserIdentityVerifiedState::Verified);
|
||||
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::Verified);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1565,10 +1566,7 @@ pub(crate) mod tests {
|
||||
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
|
||||
|
||||
// Then the value is correctly populated
|
||||
assert_eq!(
|
||||
*id.verified.read().unwrap(),
|
||||
OwnUserIdentityVerifiedState::VerificationViolation
|
||||
);
|
||||
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::VerificationViolation);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1581,10 +1579,7 @@ pub(crate) mod tests {
|
||||
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
|
||||
|
||||
// Then the old value is re-interpreted as VerificationViolation
|
||||
assert_eq!(
|
||||
*id.verified.read().unwrap(),
|
||||
OwnUserIdentityVerifiedState::VerificationViolation
|
||||
);
|
||||
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::VerificationViolation);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
@@ -25,6 +25,7 @@ use matrix_sdk_common::{
|
||||
UnableToDecryptReason, UnsignedDecryptionResult, UnsignedEventLocation, VerificationLevel,
|
||||
VerificationState,
|
||||
},
|
||||
locks::RwLock as StdRwLock,
|
||||
BoxFuture,
|
||||
};
|
||||
use ruma::{
|
||||
|
||||
@@ -18,12 +18,12 @@ use std::{
|
||||
fmt,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU64, Ordering},
|
||||
Arc, RwLock as StdRwLock,
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::deserialized_responses::WithheldCode;
|
||||
use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock};
|
||||
use ruma::{
|
||||
events::{
|
||||
room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility},
|
||||
@@ -274,7 +274,7 @@ impl OutboundGroupSession {
|
||||
request: Arc<ToDeviceRequest>,
|
||||
share_infos: ShareInfoSet,
|
||||
) {
|
||||
self.to_share_with_set.write().unwrap().insert(request_id, (request, share_infos));
|
||||
self.to_share_with_set.write().insert(request_id, (request, share_infos));
|
||||
}
|
||||
|
||||
/// Create a new `m.room_key.withheld` event content with the given code for
|
||||
@@ -310,7 +310,7 @@ impl OutboundGroupSession {
|
||||
) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
|
||||
let mut no_olm_devices = BTreeMap::new();
|
||||
|
||||
let removed = self.to_share_with_set.write().unwrap().remove(request_id);
|
||||
let removed = self.to_share_with_set.write().remove(request_id);
|
||||
if let Some((to_device, request)) = removed {
|
||||
let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
|
||||
.iter()
|
||||
@@ -332,10 +332,10 @@ impl OutboundGroupSession {
|
||||
.collect();
|
||||
no_olm_devices.insert(user_id.to_owned(), no_olms);
|
||||
|
||||
self.shared_with_set.write().unwrap().entry(user_id).or_default().extend(info);
|
||||
self.shared_with_set.write().entry(user_id).or_default().extend(info);
|
||||
}
|
||||
|
||||
if self.to_share_with_set.read().unwrap().is_empty() {
|
||||
if self.to_share_with_set.read().is_empty() {
|
||||
debug!(
|
||||
session_id = self.session_id(),
|
||||
room_id = ?self.room_id,
|
||||
@@ -347,7 +347,7 @@ impl OutboundGroupSession {
|
||||
}
|
||||
} else {
|
||||
let request_ids: Vec<String> =
|
||||
self.to_share_with_set.read().unwrap().keys().map(|k| k.to_string()).collect();
|
||||
self.to_share_with_set.read().keys().map(|k| k.to_string()).collect();
|
||||
|
||||
error!(
|
||||
all_request_ids = ?request_ids,
|
||||
@@ -540,22 +540,21 @@ impl OutboundGroupSession {
|
||||
/// Has or will the session be shared with the given user/device pair.
|
||||
pub(crate) fn is_shared_with(&self, device: &DeviceData) -> ShareState {
|
||||
// Check if we shared the session.
|
||||
let shared_state =
|
||||
self.shared_with_set.read().unwrap().get(device.user_id()).and_then(|d| {
|
||||
d.get(device.device_id()).map(|s| match s {
|
||||
ShareInfo::Shared(s) => {
|
||||
if device.curve25519_key() == Some(s.sender_key) {
|
||||
ShareState::Shared {
|
||||
message_index: s.message_index,
|
||||
olm_wedging_index: s.olm_wedging_index,
|
||||
}
|
||||
} else {
|
||||
ShareState::SharedButChangedSenderKey
|
||||
let shared_state = self.shared_with_set.read().get(device.user_id()).and_then(|d| {
|
||||
d.get(device.device_id()).map(|s| match s {
|
||||
ShareInfo::Shared(s) => {
|
||||
if device.curve25519_key() == Some(s.sender_key) {
|
||||
ShareState::Shared {
|
||||
message_index: s.message_index,
|
||||
olm_wedging_index: s.olm_wedging_index,
|
||||
}
|
||||
} else {
|
||||
ShareState::SharedButChangedSenderKey
|
||||
}
|
||||
ShareInfo::Withheld(_) => ShareState::NotShared,
|
||||
})
|
||||
});
|
||||
}
|
||||
ShareInfo::Withheld(_) => ShareState::NotShared,
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(state) = shared_state {
|
||||
state
|
||||
@@ -565,24 +564,23 @@ impl OutboundGroupSession {
|
||||
|
||||
// Find the first request that contains the given user id and
|
||||
// device ID.
|
||||
let shared =
|
||||
self.to_share_with_set.read().unwrap().values().find_map(|(_, share_info)| {
|
||||
let d = share_info.get(device.user_id())?;
|
||||
let info = d.get(device.device_id())?;
|
||||
Some(match info {
|
||||
ShareInfo::Shared(info) => {
|
||||
if device.curve25519_key() == Some(info.sender_key) {
|
||||
ShareState::Shared {
|
||||
message_index: info.message_index,
|
||||
olm_wedging_index: info.olm_wedging_index,
|
||||
}
|
||||
} else {
|
||||
ShareState::SharedButChangedSenderKey
|
||||
let shared = self.to_share_with_set.read().values().find_map(|(_, share_info)| {
|
||||
let d = share_info.get(device.user_id())?;
|
||||
let info = d.get(device.device_id())?;
|
||||
Some(match info {
|
||||
ShareInfo::Shared(info) => {
|
||||
if device.curve25519_key() == Some(info.sender_key) {
|
||||
ShareState::Shared {
|
||||
message_index: info.message_index,
|
||||
olm_wedging_index: info.olm_wedging_index,
|
||||
}
|
||||
} else {
|
||||
ShareState::SharedButChangedSenderKey
|
||||
}
|
||||
ShareInfo::Withheld(_) => ShareState::NotShared,
|
||||
})
|
||||
});
|
||||
}
|
||||
ShareInfo::Withheld(_) => ShareState::NotShared,
|
||||
})
|
||||
});
|
||||
|
||||
shared.unwrap_or(ShareState::NotShared)
|
||||
}
|
||||
@@ -591,7 +589,6 @@ impl OutboundGroupSession {
|
||||
pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
|
||||
self.shared_with_set
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(device.user_id())
|
||||
.and_then(|d| {
|
||||
let info = d.get(device.device_id())?;
|
||||
@@ -603,7 +600,7 @@ impl OutboundGroupSession {
|
||||
|
||||
// Find the first request that contains the given user id and
|
||||
// device ID.
|
||||
self.to_share_with_set.read().unwrap().values().any(|(_, share_info)| {
|
||||
self.to_share_with_set.read().values().any(|(_, share_info)| {
|
||||
share_info
|
||||
.get(device.user_id())
|
||||
.and_then(|d| d.get(device.device_id()))
|
||||
@@ -622,7 +619,7 @@ impl OutboundGroupSession {
|
||||
sender_key: Curve25519PublicKey,
|
||||
index: u32,
|
||||
) {
|
||||
self.shared_with_set.write().unwrap().entry(user_id.to_owned()).or_default().insert(
|
||||
self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert(
|
||||
device_id.to_owned(),
|
||||
ShareInfo::new_shared(sender_key, index, Default::default()),
|
||||
);
|
||||
@@ -641,7 +638,6 @@ impl OutboundGroupSession {
|
||||
ShareInfo::new_shared(sender_key, self.message_index().await, Default::default());
|
||||
self.shared_with_set
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(user_id.to_owned())
|
||||
.or_default()
|
||||
.insert(device_id.to_owned(), share_info);
|
||||
@@ -650,12 +646,12 @@ impl OutboundGroupSession {
|
||||
/// Get the list of requests that need to be sent out for this session to be
|
||||
/// marked as shared.
|
||||
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
|
||||
self.to_share_with_set.read().unwrap().values().map(|(req, _)| req.clone()).collect()
|
||||
self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect()
|
||||
}
|
||||
|
||||
/// Get the list of request ids this session is waiting for to be sent out.
|
||||
pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
|
||||
self.to_share_with_set.read().unwrap().keys().cloned().collect()
|
||||
self.to_share_with_set.read().keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Restore a Session from a previously pickled string.
|
||||
@@ -717,8 +713,8 @@ impl OutboundGroupSession {
|
||||
message_count: self.message_count.load(Ordering::SeqCst),
|
||||
shared: self.shared(),
|
||||
invalidated: self.invalidated(),
|
||||
shared_with_set: self.shared_with_set.read().unwrap().clone(),
|
||||
requests: self.to_share_with_set.read().unwrap().clone(),
|
||||
shared_with_set: self.shared_with_set.read().clone(),
|
||||
requests: self.to_share_with_set.read().clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,12 +17,14 @@ mod share_strategy;
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
fmt::Debug,
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use futures_util::future::join_all;
|
||||
use itertools::Itertools;
|
||||
use matrix_sdk_common::{deserialized_responses::WithheldCode, executor::spawn};
|
||||
use matrix_sdk_common::{
|
||||
deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
|
||||
};
|
||||
use ruma::{
|
||||
events::{AnyMessageLikeEventContent, ToDeviceEventType},
|
||||
serde::Raw,
|
||||
@@ -60,7 +62,7 @@ impl GroupSessionCache {
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&self, session: OutboundGroupSession) {
|
||||
self.sessions.write().unwrap().insert(session.room_id().to_owned(), session);
|
||||
self.sessions.write().insert(session.room_id().to_owned(), session);
|
||||
}
|
||||
|
||||
/// Either get a session for the given room from the cache or load it from
|
||||
@@ -72,20 +74,20 @@ impl GroupSessionCache {
|
||||
pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
||||
// Get the cached session, if there isn't one load one from the store
|
||||
// and put it in the cache.
|
||||
if let Some(s) = self.sessions.read().unwrap().get(room_id) {
|
||||
if let Some(s) = self.sessions.read().get(room_id) {
|
||||
return Some(s.clone());
|
||||
}
|
||||
|
||||
match self.store.get_outbound_group_session(room_id).await {
|
||||
Ok(Some(s)) => {
|
||||
{
|
||||
let mut sessions_being_shared = self.sessions_being_shared.write().unwrap();
|
||||
let mut sessions_being_shared = self.sessions_being_shared.write();
|
||||
for request_id in s.pending_request_ids() {
|
||||
sessions_being_shared.insert(request_id, s.clone());
|
||||
}
|
||||
}
|
||||
|
||||
self.sessions.write().unwrap().insert(room_id.to_owned(), s.clone());
|
||||
self.sessions.write().insert(room_id.to_owned(), s.clone());
|
||||
|
||||
Some(s)
|
||||
}
|
||||
@@ -104,20 +106,20 @@ impl GroupSessionCache {
|
||||
/// * `room_id` - The id of the room for which we should get the outbound
|
||||
/// group session.
|
||||
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
|
||||
self.sessions.read().unwrap().get(room_id).cloned()
|
||||
self.sessions.read().get(room_id).cloned()
|
||||
}
|
||||
|
||||
/// Returns whether any session is withheld with the given device and code.
|
||||
fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
|
||||
self.sessions.read().unwrap().values().any(|s| s.is_withheld_to(device, code))
|
||||
self.sessions.read().values().any(|s| s.is_withheld_to(device, code))
|
||||
}
|
||||
|
||||
fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
|
||||
self.sessions_being_shared.write().unwrap().remove(id)
|
||||
self.sessions_being_shared.write().remove(id)
|
||||
}
|
||||
|
||||
fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
|
||||
self.sessions_being_shared.write().unwrap().insert(id, session);
|
||||
self.sessions_being_shared.write().insert(id, session);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ pub(crate) async fn collect_session_recipients(
|
||||
trace!(?users, ?settings, "Calculating group session recipients");
|
||||
|
||||
let users_shared_with: BTreeSet<OwnedUserId> =
|
||||
outbound.shared_with_set.read().unwrap().keys().cloned().collect();
|
||||
outbound.shared_with_set.read().keys().cloned().collect();
|
||||
|
||||
let users_shared_with: BTreeSet<&UserId> = users_shared_with.iter().map(Deref::deref).collect();
|
||||
|
||||
@@ -339,7 +339,7 @@ fn is_session_overshared_for_user(
|
||||
let recipient_device_ids: BTreeSet<&DeviceId> =
|
||||
recipient_devices.iter().map(|d| d.device_id()).collect();
|
||||
|
||||
let guard = outbound_session.shared_with_set.read().unwrap();
|
||||
let guard = outbound_session.shared_with_set.read();
|
||||
|
||||
let Some(shared) = guard.get(user_id) else {
|
||||
return false;
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use matrix_sdk_common::failures_cache::FailuresCache;
|
||||
use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
|
||||
use ruma::{
|
||||
api::client::keys::claim_keys::v3::{
|
||||
Request as KeysClaimRequest, Response as KeysClaimResponse,
|
||||
@@ -98,7 +98,7 @@ impl SessionManager {
|
||||
|
||||
/// Mark the outgoing request as sent.
|
||||
pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
|
||||
self.outgoing_to_device_requests.write().unwrap().remove(id);
|
||||
self.outgoing_to_device_requests.write().remove(id);
|
||||
}
|
||||
|
||||
pub async fn mark_device_as_wedged(
|
||||
@@ -121,13 +121,11 @@ impl SessionManager {
|
||||
if should_unwedge {
|
||||
self.users_for_key_claim
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(device.user_id().to_owned())
|
||||
.or_default()
|
||||
.insert(device.device_id().into());
|
||||
self.wedged_devices
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(device.user_id().to_owned())
|
||||
.or_default()
|
||||
.insert(device.device_id().into());
|
||||
@@ -142,7 +140,6 @@ impl SessionManager {
|
||||
pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
|
||||
self.wedged_devices
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(device.user_id())
|
||||
.is_some_and(|d| d.contains(device.device_id()))
|
||||
}
|
||||
@@ -151,13 +148,7 @@ impl SessionManager {
|
||||
///
|
||||
/// If the device was wedged this will queue up a dummy to-device message.
|
||||
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
|
||||
if self
|
||||
.wedged_devices
|
||||
.write()
|
||||
.unwrap()
|
||||
.get_mut(user_id)
|
||||
.is_some_and(|d| d.remove(device_id))
|
||||
{
|
||||
if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id)) {
|
||||
if let Some(device) = self.store.get_device(user_id, device_id).await? {
|
||||
let (_, content) =
|
||||
device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;
|
||||
@@ -176,7 +167,6 @@ impl SessionManager {
|
||||
|
||||
self.outgoing_to_device_requests
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(request.request_id.clone(), request);
|
||||
}
|
||||
}
|
||||
@@ -278,7 +268,7 @@ impl SessionManager {
|
||||
|
||||
// Add the list of sessions that for some reason automatically need to
|
||||
// create an Olm session.
|
||||
for (user, device_ids) in self.users_for_key_claim.read().unwrap().iter() {
|
||||
for (user, device_ids) in self.users_for_key_claim.read().iter() {
|
||||
missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
|
||||
device_ids
|
||||
.iter()
|
||||
@@ -319,12 +309,12 @@ impl SessionManager {
|
||||
|
||||
// stash the details of the request so that we can refer to it when handling the
|
||||
// response
|
||||
*(self.current_key_claim_request.write().unwrap()) = result.clone();
|
||||
*(self.current_key_claim_request.write()) = result.clone();
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
|
||||
self.failed_devices.read().unwrap().get(user_id).is_some_and(|d| d.contains(device_id))
|
||||
self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
|
||||
}
|
||||
|
||||
/// This method will try to figure out for which devices a one-time key was
|
||||
@@ -354,7 +344,7 @@ impl SessionManager {
|
||||
) {
|
||||
// First check that the response is for the request we were expecting.
|
||||
let request = {
|
||||
let mut guard = self.current_key_claim_request.write().unwrap();
|
||||
let mut guard = self.current_key_claim_request.write();
|
||||
let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());
|
||||
|
||||
if Some(request_id) == expected_request_id {
|
||||
@@ -416,7 +406,7 @@ impl SessionManager {
|
||||
"Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
|
||||
);
|
||||
|
||||
let mut failed_devices_lock = self.failed_devices.write().unwrap();
|
||||
let mut failed_devices_lock = self.failed_devices.write();
|
||||
|
||||
for (user_id, device_set) in missing_devices_by_user {
|
||||
failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
|
||||
@@ -542,7 +532,6 @@ impl SessionManager {
|
||||
|
||||
self.failed_devices
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(user_id.to_owned())
|
||||
.or_default()
|
||||
.insert(device_id.to_owned());
|
||||
@@ -573,7 +562,7 @@ impl SessionManager {
|
||||
info!(sessions = ?new_sessions, "Established new Olm sessions");
|
||||
|
||||
for (user, device_map) in new_sessions {
|
||||
if let Some(user_cache) = self.failed_devices.read().unwrap().get(user) {
|
||||
if let Some(user_cache) = self.failed_devices.read().get(user) {
|
||||
user_cache.remove(device_map.into_keys());
|
||||
}
|
||||
}
|
||||
@@ -597,14 +586,9 @@ impl SessionManager {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
iter,
|
||||
ops::Deref,
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
time::Duration,
|
||||
};
|
||||
use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};
|
||||
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use matrix_sdk_test::{async_test, ruma_response_from_json};
|
||||
use ruma::{
|
||||
api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
|
||||
@@ -861,11 +845,11 @@ mod tests {
|
||||
|
||||
let curve_key = bob_device.curve25519_key().unwrap();
|
||||
|
||||
assert!(!manager.users_for_key_claim.read().unwrap().contains_key(bob.user_id()));
|
||||
assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
|
||||
assert!(!manager.is_device_wedged(&bob_device));
|
||||
manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
|
||||
assert!(manager.is_device_wedged(&bob_device));
|
||||
assert!(manager.users_for_key_claim.read().unwrap().contains_key(bob.user_id()));
|
||||
assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));
|
||||
|
||||
let (txn_id, request) =
|
||||
manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();
|
||||
@@ -885,13 +869,13 @@ mod tests {
|
||||
|
||||
let response = KeyClaimResponse::new(one_time_keys);
|
||||
|
||||
assert!(manager.outgoing_to_device_requests.read().unwrap().is_empty());
|
||||
assert!(manager.outgoing_to_device_requests.read().is_empty());
|
||||
|
||||
manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
|
||||
|
||||
assert!(!manager.is_device_wedged(&bob_device));
|
||||
assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
|
||||
assert!(!manager.outgoing_to_device_requests.read().unwrap().is_empty())
|
||||
assert!(!manager.outgoing_to_device_requests.read().is_empty())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
@@ -1012,7 +996,6 @@ mod tests {
|
||||
manager
|
||||
.failed_devices
|
||||
.write()
|
||||
.unwrap()
|
||||
.get(alice)
|
||||
.unwrap()
|
||||
.expire(&alice_account.device_id().to_owned());
|
||||
@@ -1027,7 +1010,6 @@ mod tests {
|
||||
assert!(manager
|
||||
.failed_devices
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(alice)
|
||||
.unwrap()
|
||||
.failure_count(alice_account.device_id())
|
||||
|
||||
@@ -22,10 +22,11 @@ use std::{
|
||||
fmt::Display,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, RwLock as StdRwLock, Weak,
|
||||
Arc, Weak,
|
||||
},
|
||||
};
|
||||
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
@@ -104,7 +105,6 @@ impl GroupSessionStore {
|
||||
pub fn add(&self, session: InboundGroupSession) -> bool {
|
||||
self.entries
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(session.room_id().to_owned())
|
||||
.or_default()
|
||||
.insert(session.session_id().to_owned(), session)
|
||||
@@ -113,12 +113,12 @@ impl GroupSessionStore {
|
||||
|
||||
/// Get all the group sessions the store knows about.
|
||||
pub fn get_all(&self) -> Vec<InboundGroupSession> {
|
||||
self.entries.read().unwrap().values().flat_map(HashMap::values).cloned().collect()
|
||||
self.entries.read().values().flat_map(HashMap::values).cloned().collect()
|
||||
}
|
||||
|
||||
/// Get the number of `InboundGroupSession`s we have.
|
||||
pub fn count(&self) -> usize {
|
||||
self.entries.read().unwrap().values().map(HashMap::len).sum()
|
||||
self.entries.read().values().map(HashMap::len).sum()
|
||||
}
|
||||
|
||||
/// Get a inbound group session from our store.
|
||||
@@ -128,7 +128,7 @@ impl GroupSessionStore {
|
||||
///
|
||||
/// * `session_id` - The unique id of the session.
|
||||
pub fn get(&self, room_id: &RoomId, session_id: &str) -> Option<InboundGroupSession> {
|
||||
self.entries.read().unwrap().get(room_id)?.get(session_id).cloned()
|
||||
self.entries.read().get(room_id)?.get(session_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +151,6 @@ impl DeviceStore {
|
||||
let user_id = device.user_id();
|
||||
self.entries
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(user_id.to_owned())
|
||||
.or_default()
|
||||
.insert(device.device_id().into(), device)
|
||||
@@ -160,7 +159,7 @@ impl DeviceStore {
|
||||
|
||||
/// Get the device with the given device_id and belonging to the given user.
|
||||
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
|
||||
Some(self.entries.read().unwrap().get(user_id)?.get(device_id)?.clone())
|
||||
Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
|
||||
}
|
||||
|
||||
/// Remove the device with the given device_id and belonging to the given
|
||||
@@ -168,14 +167,13 @@ impl DeviceStore {
|
||||
///
|
||||
/// Returns the device if it was removed, None if it wasn't in the store.
|
||||
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
|
||||
self.entries.write().unwrap().get_mut(user_id)?.remove(device_id)
|
||||
self.entries.write().get_mut(user_id)?.remove(device_id)
|
||||
}
|
||||
|
||||
/// Get a read-only view over all devices of the given user.
|
||||
pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
|
||||
self.entries
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(user_id.to_owned())
|
||||
.or_default()
|
||||
.iter()
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
convert::Infallible,
|
||||
sync::RwLock as StdRwLock,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use matrix_sdk_common::store_locks::memory_store_helper::try_take_leased_lock;
|
||||
use matrix_sdk_common::{
|
||||
locks::RwLock as StdRwLock, store_locks::memory_store_helper::try_take_leased_lock,
|
||||
};
|
||||
use ruma::{
|
||||
events::secret::request::SecretName, time::Instant, DeviceId, OwnedDeviceId, OwnedRoomId,
|
||||
OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
|
||||
@@ -143,7 +144,7 @@ impl MemoryStore {
|
||||
}
|
||||
|
||||
fn save_sessions(&self, sessions: Vec<Session>) {
|
||||
let mut session_store = self.sessions.write().unwrap();
|
||||
let mut session_store = self.sessions.write();
|
||||
|
||||
for session in sessions {
|
||||
let entry = session_store.entry(session.sender_key().to_base64()).or_default();
|
||||
@@ -159,12 +160,11 @@ impl MemoryStore {
|
||||
fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
|
||||
self.outbound_group_sessions
|
||||
.write()
|
||||
.unwrap()
|
||||
.extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
|
||||
}
|
||||
|
||||
fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
|
||||
*self.private_identity.write().unwrap() = private_identity;
|
||||
*self.private_identity.write() = private_identity;
|
||||
}
|
||||
|
||||
/// Return all the [`InboundGroupSession`]s we have, paired with the
|
||||
@@ -176,7 +176,6 @@ impl MemoryStore {
|
||||
let lookup = |s: &InboundGroupSession| {
|
||||
self.inbound_group_sessions_backed_up_to
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&s.room_id)?
|
||||
.get(s.session_id())
|
||||
.cloned()
|
||||
@@ -202,11 +201,11 @@ impl CryptoStore for MemoryStore {
|
||||
type Error = Infallible;
|
||||
|
||||
async fn load_account(&self) -> Result<Option<Account>> {
|
||||
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone()))
|
||||
Ok(self.account.read().as_ref().map(|acc| acc.deep_clone()))
|
||||
}
|
||||
|
||||
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
|
||||
Ok(self.private_identity.read().unwrap().clone())
|
||||
Ok(self.private_identity.read().clone())
|
||||
}
|
||||
|
||||
async fn next_batch_token(&self) -> Result<Option<String>> {
|
||||
@@ -215,7 +214,7 @@ impl CryptoStore for MemoryStore {
|
||||
|
||||
async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
|
||||
if let Some(account) = changes.account {
|
||||
*self.account.write().unwrap() = Some(account);
|
||||
*self.account.write() = Some(account);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -232,7 +231,7 @@ impl CryptoStore for MemoryStore {
|
||||
self.delete_devices(changes.devices.deleted);
|
||||
|
||||
{
|
||||
let mut identities = self.identities.write().unwrap();
|
||||
let mut identities = self.identities.write();
|
||||
for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
|
||||
identities.insert(
|
||||
identity.user_id().to_owned(),
|
||||
@@ -242,15 +241,15 @@ impl CryptoStore for MemoryStore {
|
||||
}
|
||||
|
||||
{
|
||||
let mut olm_hashes = self.olm_hashes.write().unwrap();
|
||||
let mut olm_hashes = self.olm_hashes.write();
|
||||
for hash in changes.message_hashes {
|
||||
olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut outgoing_key_requests = self.outgoing_key_requests.write().unwrap();
|
||||
let mut key_requests_by_info = self.key_requests_by_info.write().unwrap();
|
||||
let mut outgoing_key_requests = self.outgoing_key_requests.write();
|
||||
let mut key_requests_by_info = self.key_requests_by_info.write();
|
||||
|
||||
for key_request in changes.key_requests {
|
||||
let id = key_request.request_id.clone();
|
||||
@@ -275,14 +274,14 @@ impl CryptoStore for MemoryStore {
|
||||
}
|
||||
|
||||
{
|
||||
let mut secret_inbox = self.secret_inbox.write().unwrap();
|
||||
let mut secret_inbox = self.secret_inbox.write();
|
||||
for secret in changes.secrets {
|
||||
secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut direct_withheld_info = self.direct_withheld_info.write().unwrap();
|
||||
let mut direct_withheld_info = self.direct_withheld_info.write();
|
||||
for (room_id, data) in changes.withheld_session_info {
|
||||
for (session_id, event) in data {
|
||||
direct_withheld_info
|
||||
@@ -298,7 +297,7 @@ impl CryptoStore for MemoryStore {
|
||||
}
|
||||
|
||||
if !changes.room_settings.is_empty() {
|
||||
let mut settings = self.room_settings.write().unwrap();
|
||||
let mut settings = self.room_settings.write();
|
||||
settings.extend(changes.room_settings);
|
||||
}
|
||||
|
||||
@@ -311,7 +310,7 @@ impl CryptoStore for MemoryStore {
|
||||
backed_up_to_version: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let mut inbound_group_sessions_backed_up_to =
|
||||
self.inbound_group_sessions_backed_up_to.write().unwrap();
|
||||
self.inbound_group_sessions_backed_up_to.write();
|
||||
|
||||
for session in sessions {
|
||||
let room_id = session.room_id();
|
||||
@@ -339,7 +338,7 @@ impl CryptoStore for MemoryStore {
|
||||
}
|
||||
|
||||
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
|
||||
Ok(self.sessions.read().unwrap().get(sender_key).cloned())
|
||||
Ok(self.sessions.read().get(sender_key).cloned())
|
||||
}
|
||||
|
||||
async fn get_inbound_group_session(
|
||||
@@ -358,7 +357,6 @@ impl CryptoStore for MemoryStore {
|
||||
Ok(self
|
||||
.direct_withheld_info
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.and_then(|e| Some(e.get(session_id)?.to_owned())))
|
||||
}
|
||||
@@ -458,7 +456,7 @@ impl CryptoStore for MemoryStore {
|
||||
room_and_session_ids: &[(&RoomId, &str)],
|
||||
) -> Result<()> {
|
||||
let mut inbound_group_sessions_backed_up_to =
|
||||
self.inbound_group_sessions_backed_up_to.write().unwrap();
|
||||
self.inbound_group_sessions_backed_up_to.write();
|
||||
|
||||
for &(room_id, session_id) in room_and_session_ids {
|
||||
let session = self.inbound_group_sessions.get(room_id, session_id);
|
||||
@@ -506,15 +504,15 @@ impl CryptoStore for MemoryStore {
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<OutboundGroupSession>> {
|
||||
Ok(self.outbound_group_sessions.read().unwrap().get(room_id).cloned())
|
||||
Ok(self.outbound_group_sessions.read().get(room_id).cloned())
|
||||
}
|
||||
|
||||
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
|
||||
Ok(self.tracked_users.read().unwrap().values().cloned().collect())
|
||||
Ok(self.tracked_users.read().values().cloned().collect())
|
||||
}
|
||||
|
||||
async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
|
||||
self.tracked_users.write().unwrap().extend(tracked_users.iter().map(|(user_id, dirty)| {
|
||||
self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
|
||||
let user_id: OwnedUserId = user_id.to_owned().into();
|
||||
(user_id.clone(), TrackedUser { user_id, dirty: *dirty })
|
||||
}));
|
||||
@@ -543,7 +541,7 @@ impl CryptoStore for MemoryStore {
|
||||
}
|
||||
|
||||
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
|
||||
let serialized = self.identities.read().unwrap().get(user_id).cloned();
|
||||
let serialized = self.identities.read().get(user_id).cloned();
|
||||
match serialized {
|
||||
None => Ok(None),
|
||||
Some(serialized) => {
|
||||
@@ -557,7 +555,6 @@ impl CryptoStore for MemoryStore {
|
||||
Ok(self
|
||||
.olm_hashes
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(message_hash.sender_key.to_owned())
|
||||
.or_default()
|
||||
.contains(&message_hash.hash))
|
||||
@@ -567,7 +564,7 @@ impl CryptoStore for MemoryStore {
|
||||
&self,
|
||||
request_id: &TransactionId,
|
||||
) -> Result<Option<GossipRequest>> {
|
||||
Ok(self.outgoing_key_requests.read().unwrap().get(request_id).cloned())
|
||||
Ok(self.outgoing_key_requests.read().get(request_id).cloned())
|
||||
}
|
||||
|
||||
async fn get_secret_request_by_info(
|
||||
@@ -579,16 +576,14 @@ impl CryptoStore for MemoryStore {
|
||||
Ok(self
|
||||
.key_requests_by_info
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&key_info_string)
|
||||
.and_then(|i| self.outgoing_key_requests.read().unwrap().get(i).cloned()))
|
||||
.and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
|
||||
}
|
||||
|
||||
async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
|
||||
Ok(self
|
||||
.outgoing_key_requests
|
||||
.read()
|
||||
.unwrap()
|
||||
.values()
|
||||
.filter(|req| !req.sent_out)
|
||||
.cloned()
|
||||
@@ -596,10 +591,10 @@ impl CryptoStore for MemoryStore {
|
||||
}
|
||||
|
||||
async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
|
||||
let req = self.outgoing_key_requests.write().unwrap().remove(request_id);
|
||||
let req = self.outgoing_key_requests.write().remove(request_id);
|
||||
if let Some(i) = req {
|
||||
let key_info_string = encode_key_info(&i.info);
|
||||
self.key_requests_by_info.write().unwrap().remove(&key_info_string);
|
||||
self.key_requests_by_info.write().remove(&key_info_string);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -609,36 +604,30 @@ impl CryptoStore for MemoryStore {
|
||||
&self,
|
||||
secret_name: &SecretName,
|
||||
) -> Result<Vec<GossippedSecret>> {
|
||||
Ok(self
|
||||
.secret_inbox
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(secret_name.to_string())
|
||||
.or_default()
|
||||
.to_owned())
|
||||
Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
|
||||
}
|
||||
|
||||
async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
|
||||
self.secret_inbox.write().unwrap().remove(secret_name.as_str());
|
||||
self.secret_inbox.write().remove(secret_name.as_str());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
|
||||
Ok(self.room_settings.read().unwrap().get(room_id).cloned())
|
||||
Ok(self.room_settings.read().get(room_id).cloned())
|
||||
}
|
||||
|
||||
async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.custom_values.read().unwrap().get(key).cloned())
|
||||
Ok(self.custom_values.read().get(key).cloned())
|
||||
}
|
||||
|
||||
async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
|
||||
self.custom_values.write().unwrap().insert(key.to_owned(), value);
|
||||
self.custom_values.write().insert(key.to_owned(), value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_custom_value(&self, key: &str) -> Result<()> {
|
||||
self.custom_values.write().unwrap().remove(key);
|
||||
self.custom_values.write().remove(key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -648,7 +637,7 @@ impl CryptoStore for MemoryStore {
|
||||
key: &str,
|
||||
holder: &str,
|
||||
) -> Result<bool> {
|
||||
Ok(try_take_leased_lock(&mut self.leases.write().unwrap(), lease_duration_ms, key, holder))
|
||||
Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1167,7 +1156,7 @@ mod integration_tests {
|
||||
|
||||
impl MemoryStore {
|
||||
fn get_static_account(&self) -> Option<StaticAccountData> {
|
||||
self.account.read().unwrap().as_ref().map(|acc| acc.static_data().clone())
|
||||
self.account.read().as_ref().map(|acc| acc.static_data().clone())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,13 +43,14 @@ use std::{
|
||||
fmt::Debug,
|
||||
ops::Deref,
|
||||
pin::pin,
|
||||
sync::{atomic::Ordering, Arc, RwLock as StdRwLock},
|
||||
sync::{atomic::Ordering, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use as_variant::as_variant;
|
||||
use futures_core::Stream;
|
||||
use futures_util::StreamExt;
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use ruma::{
|
||||
encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
|
||||
OwnedRoomId, OwnedUserId, UserId,
|
||||
@@ -155,7 +156,7 @@ impl KeyQueryManager {
|
||||
let tracked_users = cache.store.load_tracked_users().await?;
|
||||
|
||||
let mut query_users_lock = self.users_for_key_query.lock().await;
|
||||
let mut tracked_users_cache = cache.tracked_users.write().unwrap();
|
||||
let mut tracked_users_cache = cache.tracked_users.write();
|
||||
for user in tracked_users {
|
||||
tracked_users_cache.insert(user.user_id.to_owned());
|
||||
|
||||
@@ -245,7 +246,7 @@ impl SyncedKeyQueryManager<'_> {
|
||||
let mut key_query_lock = self.manager.users_for_key_query.lock().await;
|
||||
|
||||
{
|
||||
let mut tracked_users = self.cache.tracked_users.write().unwrap();
|
||||
let mut tracked_users = self.cache.tracked_users.write();
|
||||
for user_id in users {
|
||||
if tracked_users.insert(user_id.to_owned()) {
|
||||
key_query_lock.insert_user(user_id);
|
||||
@@ -271,7 +272,7 @@ impl SyncedKeyQueryManager<'_> {
|
||||
let mut key_query_lock = self.manager.users_for_key_query.lock().await;
|
||||
|
||||
{
|
||||
let tracked_users = &self.cache.tracked_users.read().unwrap();
|
||||
let tracked_users = &self.cache.tracked_users.read();
|
||||
for user_id in users {
|
||||
if tracked_users.contains(user_id) {
|
||||
key_query_lock.insert_user(user_id);
|
||||
@@ -297,7 +298,7 @@ impl SyncedKeyQueryManager<'_> {
|
||||
let mut key_query_lock = self.manager.users_for_key_query.lock().await;
|
||||
|
||||
{
|
||||
let tracked_users = self.cache.tracked_users.read().unwrap();
|
||||
let tracked_users = self.cache.tracked_users.read();
|
||||
for user_id in users {
|
||||
if tracked_users.contains(user_id) {
|
||||
let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
|
||||
@@ -330,7 +331,7 @@ impl SyncedKeyQueryManager<'_> {
|
||||
|
||||
/// See the docs for [`crate::OlmMachine::tracked_users()`].
|
||||
pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
|
||||
self.cache.tracked_users.read().unwrap().iter().cloned().collect()
|
||||
self.cache.tracked_users.read().iter().cloned().collect()
|
||||
}
|
||||
|
||||
/// Mark the given user as being tracked for device lists, and mark that it
|
||||
@@ -340,7 +341,7 @@ impl SyncedKeyQueryManager<'_> {
|
||||
/// next time [`Store::users_for_key_query()`] is called.
|
||||
pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
|
||||
self.manager.users_for_key_query.lock().await.insert_user(user);
|
||||
self.cache.tracked_users.write().unwrap().insert(user.to_owned());
|
||||
self.cache.tracked_users.write().insert(user.to_owned());
|
||||
|
||||
self.cache.store.save_tracked_users(&[(user, true)]).await
|
||||
}
|
||||
|
||||
@@ -12,12 +12,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
};
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use as_variant::as_variant;
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId};
|
||||
#[cfg(feature = "qrcode")]
|
||||
use tracing::debug;
|
||||
@@ -69,7 +67,7 @@ impl VerificationCache {
|
||||
#[cfg(test)]
|
||||
#[allow(dead_code)]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.inner.verification.read().unwrap().values().all(|m| m.is_empty())
|
||||
self.inner.verification.read().values().all(|m| m.is_empty())
|
||||
}
|
||||
|
||||
/// Add a new `Verification` object to the cache, this will cancel any
|
||||
@@ -78,7 +76,7 @@ impl VerificationCache {
|
||||
pub fn insert(&self, verification: impl Into<Verification>) {
|
||||
let verification = verification.into();
|
||||
|
||||
let mut verification_write_guard = self.inner.verification.write().unwrap();
|
||||
let mut verification_write_guard = self.inner.verification.write();
|
||||
let user_verifications =
|
||||
verification_write_guard.entry(verification.other_user().to_owned()).or_default();
|
||||
|
||||
@@ -150,22 +148,21 @@ impl VerificationCache {
|
||||
self.inner
|
||||
.verification
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(verification.other_user().to_owned())
|
||||
.or_default()
|
||||
.insert(verification.flow_id().to_owned(), verification.clone());
|
||||
}
|
||||
|
||||
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
|
||||
self.inner.verification.read().unwrap().get(sender)?.get(flow_id).cloned()
|
||||
self.inner.verification.read().get(sender)?.get(flow_id).cloned()
|
||||
}
|
||||
|
||||
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
|
||||
self.inner.outgoing_requests.read().unwrap().values().cloned().collect()
|
||||
self.inner.outgoing_requests.read().values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
|
||||
let verification = &mut self.inner.verification.write().unwrap();
|
||||
let verification = &mut self.inner.verification.write();
|
||||
|
||||
for user_verification in verification.values_mut() {
|
||||
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
|
||||
@@ -186,7 +183,7 @@ impl VerificationCache {
|
||||
|
||||
pub fn add_request(&self, request: OutgoingRequest) {
|
||||
trace!("Adding an outgoing request {:?}", request);
|
||||
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
|
||||
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
|
||||
}
|
||||
|
||||
pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
|
||||
@@ -211,7 +208,7 @@ impl VerificationCache {
|
||||
"Storing the request info, waiting for the request to be marked as sent"
|
||||
);
|
||||
|
||||
self.inner.flow_ids_waiting_for_response.write().unwrap().insert(
|
||||
self.inner.flow_ids_waiting_for_response.write().insert(
|
||||
request_info.request_id.to_owned(),
|
||||
(recipient.to_owned(), request_info.flow_id),
|
||||
);
|
||||
@@ -235,7 +232,7 @@ impl VerificationCache {
|
||||
request: Arc::new(request.into()),
|
||||
};
|
||||
|
||||
self.inner.outgoing_requests.write().unwrap().insert(request_id, request);
|
||||
self.inner.outgoing_requests.write().insert(request_id, request);
|
||||
}
|
||||
|
||||
OutgoingContent::Room(r, c) => {
|
||||
@@ -247,18 +244,18 @@ impl VerificationCache {
|
||||
request_id: request_id.clone(),
|
||||
};
|
||||
|
||||
self.inner.outgoing_requests.write().unwrap().insert(request_id, request);
|
||||
self.inner.outgoing_requests.write().insert(request_id, request);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mark_request_as_sent(&self, request_id: &TransactionId) {
|
||||
if let Some(request_id) = self.inner.outgoing_requests.write().unwrap().remove(request_id) {
|
||||
if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) {
|
||||
trace!(?request_id, "Marking a verification HTTP request as sent");
|
||||
}
|
||||
|
||||
if let Some((user_id, flow_id)) =
|
||||
self.inner.flow_ids_waiting_for_response.read().unwrap().get(request_id)
|
||||
self.inner.flow_ids_waiting_for_response.read().get(request_id)
|
||||
{
|
||||
if let Some(verification) = self.get(user_id, flow_id.as_str()) {
|
||||
match verification {
|
||||
|
||||
@@ -12,11 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock as StdRwLock},
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use matrix_sdk_common::locks::RwLock as StdRwLock;
|
||||
use ruma::{
|
||||
events::{
|
||||
key::verification::VerificationMethod, AnyToDeviceEvent, AnyToDeviceEventContent,
|
||||
@@ -153,13 +151,12 @@ impl VerificationMachine {
|
||||
user_id: &UserId,
|
||||
flow_id: impl AsRef<str>,
|
||||
) -> Option<VerificationRequest> {
|
||||
self.requests.read().unwrap().get(user_id)?.get(flow_id.as_ref()).cloned()
|
||||
self.requests.read().get(user_id)?.get(flow_id.as_ref()).cloned()
|
||||
}
|
||||
|
||||
pub fn get_requests(&self, user_id: &UserId) -> Vec<VerificationRequest> {
|
||||
self.requests
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(user_id)
|
||||
.map(|v| v.iter().map(|(_, value)| value.clone()).collect())
|
||||
.unwrap_or_default()
|
||||
@@ -174,7 +171,7 @@ impl VerificationMachine {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut requests = self.requests.write().unwrap();
|
||||
let mut requests = self.requests.write();
|
||||
let user_requests = requests.entry(request.other_user().to_owned()).or_default();
|
||||
|
||||
// Cancel all the old verifications requests as well as the new one we
|
||||
@@ -247,7 +244,7 @@ impl VerificationMachine {
|
||||
let mut events = vec![];
|
||||
|
||||
let mut requests: Vec<OutgoingVerificationRequest> = {
|
||||
let mut requests = self.requests.write().unwrap();
|
||||
let mut requests = self.requests.write();
|
||||
|
||||
for user_verification in requests.values_mut() {
|
||||
user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled()));
|
||||
|
||||
@@ -12,12 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
matches,
|
||||
sync::{Arc, Mutex},
|
||||
time::Duration,
|
||||
};
|
||||
use std::{matches, sync::Arc, time::Duration};
|
||||
|
||||
use matrix_sdk_common::locks::Mutex;
|
||||
use ruma::{
|
||||
events::{
|
||||
key::verification::{
|
||||
@@ -356,7 +353,7 @@ impl<S: Clone> SasState<S> {
|
||||
let their_public_key = Curve25519PublicKey::from_slice(content.public_key().as_bytes())
|
||||
.map_err(|_| CancelCode::from("Invalid public key"))?;
|
||||
|
||||
if let Some(sas) = self.inner.lock().unwrap().take() {
|
||||
if let Some(sas) = self.inner.lock().take() {
|
||||
sas.diffie_hellman(their_public_key).map_err(|_| "Invalid public key".into())
|
||||
} else {
|
||||
Err(CancelCode::UnexpectedMessage)
|
||||
@@ -1127,7 +1124,7 @@ impl SasState<KeysExchanged> {
|
||||
/// second element the English description of the emoji.
|
||||
pub fn get_emoji(&self) -> [Emoji; 7] {
|
||||
get_emoji(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
self.state.we_started,
|
||||
@@ -1140,7 +1137,7 @@ impl SasState<KeysExchanged> {
|
||||
/// numbers can be converted to a unique emoji defined by the spec.
|
||||
pub fn get_emoji_index(&self) -> [u8; 7] {
|
||||
get_emoji_index(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
self.state.we_started,
|
||||
@@ -1153,7 +1150,7 @@ impl SasState<KeysExchanged> {
|
||||
/// the short auth string.
|
||||
pub fn get_decimal(&self) -> (u16, u16, u16) {
|
||||
get_decimal(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
self.state.we_started,
|
||||
@@ -1175,7 +1172,7 @@ impl SasState<KeysExchanged> {
|
||||
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
|
||||
|
||||
let (devices, master_keys) = receive_mac_event(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
sender,
|
||||
@@ -1239,7 +1236,7 @@ impl SasState<Confirmed> {
|
||||
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
|
||||
|
||||
let (devices, master_keys) = receive_mac_event(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
sender,
|
||||
@@ -1283,7 +1280,7 @@ impl SasState<Confirmed> {
|
||||
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
|
||||
|
||||
let (devices, master_keys) = receive_mac_event(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
sender,
|
||||
@@ -1315,7 +1312,7 @@ impl SasState<Confirmed> {
|
||||
/// The content needs to be automatically sent to the other side.
|
||||
pub fn as_content(&self) -> OutgoingContent {
|
||||
get_mac_content(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
&self.verification_flow_id,
|
||||
self.state.accepted_protocols.message_auth_code,
|
||||
@@ -1376,7 +1373,7 @@ impl SasState<MacReceived> {
|
||||
/// second element the English description of the emoji.
|
||||
pub fn get_emoji(&self) -> [Emoji; 7] {
|
||||
get_emoji(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
self.state.we_started,
|
||||
@@ -1389,7 +1386,7 @@ impl SasState<MacReceived> {
|
||||
/// numbers can be converted to a unique emoji defined by the spec.
|
||||
pub fn get_emoji_index(&self) -> [u8; 7] {
|
||||
get_emoji_index(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
self.state.we_started,
|
||||
@@ -1402,7 +1399,7 @@ impl SasState<MacReceived> {
|
||||
/// the short auth string.
|
||||
pub fn get_decimal(&self) -> (u16, u16, u16) {
|
||||
get_decimal(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
self.verification_flow_id.as_str(),
|
||||
self.state.we_started,
|
||||
@@ -1417,7 +1414,7 @@ impl SasState<WaitingForDone> {
|
||||
/// wasn't already sent.
|
||||
pub fn as_content(&self) -> OutgoingContent {
|
||||
get_mac_content(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
&self.verification_flow_id,
|
||||
self.state.accepted_protocols.message_auth_code,
|
||||
@@ -1480,7 +1477,7 @@ impl SasState<Done> {
|
||||
/// wasn't already sent.
|
||||
pub fn as_content(&self) -> OutgoingContent {
|
||||
get_mac_content(
|
||||
&self.state.sas.lock().unwrap(),
|
||||
&self.state.sas.lock(),
|
||||
&self.ids,
|
||||
&self.verification_flow_id,
|
||||
self.state.accepted_protocols.message_auth_code,
|
||||
|
||||
Reference in New Issue
Block a user