refactor(crypto): Use the simplified locks across the crypto crate

This commit is contained in:
Damir Jelić
2025-01-08 15:51:01 +01:00
parent 46dc2a9c5e
commit 62567ca6eb
16 changed files with 213 additions and 258 deletions

View File

@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use matrix_sdk_common::locks::Mutex;
use ruma::{
api::client::backup::{EncryptedSessionDataInit, KeyBackupData, KeyBackupDataInit},
serde::Base64,
@@ -87,7 +88,7 @@ impl MegolmV1BackupKey {
/// Get the backup version that this key is used with, if any.
pub fn backup_version(&self) -> Option<String> {
self.inner.version.lock().unwrap().clone()
self.inner.version.lock().clone()
}
/// Set the backup version that this `MegolmV1BackupKey` will be used with.
@@ -95,7 +96,7 @@ impl MegolmV1BackupKey {
/// The key won't be able to encrypt room keys unless a version has been
/// set.
pub fn set_version(&self, version: String) {
*self.inner.version.lock().unwrap() = Some(version);
*self.inner.version.lock() = Some(version);
}
/// Export the given inbound group session, and encrypt the data, ready for

View File

@@ -25,10 +25,11 @@ use std::{
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock as StdRwLock,
Arc,
},
};
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{
api::client::keys::claim_keys::v3::Request as KeysClaimRequest,
events::secret::request::{
@@ -168,14 +169,13 @@ impl GossipMachine {
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> =
self.inner.outgoing_requests.read().unwrap().values().cloned().collect();
self.inner.outgoing_requests.read().values().cloned().collect();
key_requests.extend(key_forwards);
let users_for_key_claim: BTreeMap<_, _> = self
.inner
.users_for_key_claim
.read()
.unwrap()
.iter()
.map(|(key, value)| {
let device_map = value
@@ -213,7 +213,7 @@ impl GossipMachine {
trace!("Received a secret request event from ourselves, ignoring")
} else {
let request_info = event.to_request_info();
self.inner.incoming_key_requests.write().unwrap().insert(request_info, event);
self.inner.incoming_key_requests.write().insert(request_info, event);
}
}
@@ -229,8 +229,7 @@ impl GossipMachine {
) -> OlmResult<Vec<Session>> {
let mut changed_sessions = Vec::new();
let incoming_key_requests =
mem::take(&mut *self.inner.incoming_key_requests.write().unwrap());
let incoming_key_requests = mem::take(&mut *self.inner.incoming_key_requests.write());
for event in incoming_key_requests.values() {
if let Some(s) = match event {
@@ -254,7 +253,6 @@ impl GossipMachine {
self.inner
.users_for_key_claim
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
@@ -275,7 +273,7 @@ impl GossipMachine {
/// * `device_id` - The device ID of the device that got the Olm session.
pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) {
if let Entry::Occupied(mut e) =
self.inner.users_for_key_claim.write().unwrap().entry(user_id.to_owned())
self.inner.users_for_key_claim.write().entry(user_id.to_owned())
{
e.get_mut().remove(device_id);
@@ -284,7 +282,7 @@ impl GossipMachine {
}
}
let mut incoming_key_requests = self.inner.incoming_key_requests.write().unwrap();
let mut incoming_key_requests = self.inner.incoming_key_requests.write();
for (key, event) in self.inner.wait_queue.remove(user_id, device_id) {
incoming_key_requests.entry(key).or_insert(event);
}
@@ -555,7 +553,7 @@ impl GossipMachine {
request_id: request.txn_id.clone(),
request: Arc::new(request.into()),
};
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
Ok(used_session)
}
@@ -581,7 +579,7 @@ impl GossipMachine {
request_id: request.txn_id.clone(),
request: Arc::new(request.into()),
};
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
Ok(used_session)
}
@@ -824,7 +822,7 @@ impl GossipMachine {
self.save_outgoing_key_info(info).await?;
}
self.inner.outgoing_requests.write().unwrap().remove(id);
self.inner.outgoing_requests.write().remove(id);
Ok(())
}
@@ -840,13 +838,13 @@ impl GossipMachine {
"Successfully received a secret, removing the request"
);
self.inner.outgoing_requests.write().unwrap().remove(&key_info.request_id);
self.inner.outgoing_requests.write().remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(key_info).await?;
let request = key_info.to_cancellation(self.device_id());
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
Ok(())
}
@@ -1511,7 +1509,6 @@ mod tests {
.inner
.outgoing_requests
.read()
.unwrap()
.first_key_value()
.map(|(_, r)| r.request_id.clone())
.unwrap();
@@ -1692,7 +1689,7 @@ mod tests {
alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap();
// Bob doesn't have any outgoing requests.
assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(bob_machine.inner.outgoing_requests.read().is_empty());
// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
@@ -1702,7 +1699,7 @@ mod tests {
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Now bob does have an outgoing request.
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(!bob_machine.inner.outgoing_requests.read().is_empty());
// Get the request and convert it to a encrypted to-device event.
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
@@ -1774,7 +1771,7 @@ mod tests {
alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap();
// Bob doesn't have any outgoing requests.
assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(bob_machine.inner.outgoing_requests.read().is_empty());
// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
@@ -1783,7 +1780,7 @@ mod tests {
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Now bob does have an outgoing request.
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(!bob_machine.inner.outgoing_requests.read().is_empty());
// Get the request and convert it to a encrypted to-device event.
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
@@ -1875,13 +1872,13 @@ mod tests {
};
// No secret found
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
alice_machine.receive_incoming_secret_request(&event);
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
// No device found
alice_machine.inner.store.reset_cross_signing_identity().await;
@@ -1890,7 +1887,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
alice_machine.inner.store.save_device_data(&[bob_device]).await.unwrap();
@@ -1901,7 +1898,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
let event = RumaToDeviceEvent {
sender: alice_id().to_owned(),
@@ -1918,7 +1915,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
// We need a trusted device, otherwise we won't serve secrets
alice_device.set_trust_state(LocalTrust::Verified);
@@ -1929,7 +1926,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(!alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(!alice_machine.inner.outgoing_requests.read().is_empty());
}
#[async_test]
@@ -2053,7 +2050,7 @@ mod tests {
// Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
assert!(bob_machine.inner.users_for_key_claim.read().is_empty());
assert!(bob_machine.inner.wait_queue.is_empty());
// Receive the room key request from alice.
@@ -2068,7 +2065,7 @@ mod tests {
bob_machine.outgoing_to_device_requests().await.unwrap()[0].request(),
AnyOutgoingRequest::KeysClaim(_)
);
assert!(!bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
assert!(!bob_machine.inner.users_for_key_claim.read().is_empty());
assert!(!bob_machine.inner.wait_queue.is_empty());
let (alice_session, bob_session) = alice_machine
@@ -2096,7 +2093,7 @@ mod tests {
bob_machine.inner.store.save_sessions(&[bob_session]).await.unwrap();
bob_machine.retry_keyshare(alice_id(), alice_device_id());
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
assert!(bob_machine.inner.users_for_key_claim.read().is_empty());
{
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();

View File

@@ -16,10 +16,11 @@ mod machine;
use std::{
collections::{BTreeMap, BTreeSet},
sync::{Arc, RwLock as StdRwLock},
sync::Arc,
};
pub(crate) use machine::GossipMachine;
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{
events::{
room_key_request::{Action, ToDeviceRoomKeyRequestEventContent},
@@ -323,7 +324,7 @@ impl WaitQueue {
#[cfg(all(test, feature = "automatic-room-key-forwarding"))]
fn is_empty(&self) -> bool {
let read_guard = self.inner.read().unwrap();
let read_guard = self.inner.read();
read_guard.requests_ids_waiting.is_empty()
&& read_guard.requests_waiting_for_session.is_empty()
}
@@ -337,13 +338,14 @@ impl WaitQueue {
);
let ids_waiting_key = (device.user_id().to_owned(), device.device_id().into());
let mut write_guard = self.inner.write().unwrap();
let mut write_guard = self.inner.write();
write_guard.requests_waiting_for_session.insert(requests_waiting_key, event);
write_guard.requests_ids_waiting.entry(ids_waiting_key).or_default().insert(request_id);
}
fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Vec<(RequestInfo, RequestEvent)> {
let mut write_guard = self.inner.write().unwrap();
let mut write_guard = self.inner.write();
write_guard
.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into()))

View File

@@ -17,11 +17,11 @@ use std::{
ops::Deref,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
Arc,
},
};
use matrix_sdk_common::deserialized_responses::WithheldCode;
use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock};
use ruma::{
api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest,
events::{key::verification::VerificationMethod, AnyToDeviceEventContent},
@@ -470,7 +470,7 @@ impl Device {
) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
let (used_session, raw_encrypted) = self.encrypt(event_type, content).await?;
// perist the used session
// Persist the used session
self.verification_machine
.store
.save_changes(Changes { sessions: vec![used_session], ..Default::default() })
@@ -626,7 +626,7 @@ impl DeviceData {
/// Get the trust state of the device.
pub fn local_trust_state(&self) -> LocalTrust {
*self.trust_state.read().unwrap()
*self.trust_state.read()
}
/// Is the device locally marked as trusted.
@@ -646,7 +646,7 @@ impl DeviceData {
/// Note: This should only done in the crypto store where the trust state
/// can be stored.
pub(crate) fn set_trust_state(&self, state: LocalTrust) {
*self.trust_state.write().unwrap() = state;
*self.trust_state.write() = state;
}
pub(crate) fn mark_withheld_code_as_sent(&self) {

View File

@@ -17,11 +17,12 @@ use std::{
ops::{Deref, DerefMut},
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
Arc,
},
};
use as_variant::as_variant;
use matrix_sdk_common::locks::RwLock;
use ruma::{
api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest,
events::{
@@ -684,7 +685,7 @@ impl From<OtherUserIdentityData> for OtherUserIdentityDataSerializer {
user_id: value.user_id.clone(),
master_key: value.master_key().to_owned(),
self_signing_key: value.self_signing_key().to_owned(),
pinned_master_key: value.pinned_master_key.read().unwrap().clone(),
pinned_master_key: value.pinned_master_key.read().clone(),
previously_verified: value.previously_verified.load(Ordering::SeqCst),
};
OtherUserIdentityDataSerializer {
@@ -787,7 +788,7 @@ impl OtherUserIdentityData {
/// which is not verified and is in pin violation. See
/// [`OtherUserIdentity::identity_needs_user_approval`].
pub(crate) fn pin(&self) {
let mut m = self.pinned_master_key.write().unwrap();
let mut m = self.pinned_master_key.write();
*m = self.master_key.as_ref().clone()
}
@@ -828,7 +829,7 @@ impl OtherUserIdentityData {
/// accept and pin the new identity, perform a verification, or
/// stop communications.
pub(crate) fn has_pin_violation(&self) -> bool {
let pinned_master_key = self.pinned_master_key.read().unwrap();
let pinned_master_key = self.pinned_master_key.read();
pinned_master_key.get_first_key() != self.master_key().get_first_key()
}
@@ -858,7 +859,7 @@ impl OtherUserIdentityData {
// the previous pinned master key.
// This identity will have a pin violation until the new master key is pinned
// (see `has_pin_violation()`).
let pinned_master_key = self.pinned_master_key.read().unwrap().clone();
let pinned_master_key = self.pinned_master_key.read().clone();
// Check if the new master_key is signed by our own **verified**
// user_signing_key. If the identity was verified we remember it.
@@ -947,7 +948,7 @@ impl PartialEq for OwnUserIdentityData {
&& self.master_key == other.master_key
&& self.self_signing_key == other.self_signing_key
&& self.user_signing_key == other.user_signing_key
&& *self.verified.read().unwrap() == *other.verified.read().unwrap()
&& *self.verified.read() == *other.verified.read()
&& self.master_key.signatures() == other.master_key.signatures()
}
}
@@ -1067,12 +1068,12 @@ impl OwnUserIdentityData {
/// Mark our identity as verified.
pub fn mark_as_verified(&self) {
*self.verified.write().unwrap() = OwnUserIdentityVerifiedState::Verified;
*self.verified.write() = OwnUserIdentityVerifiedState::Verified;
}
/// Mark our identity as unverified.
pub(crate) fn mark_as_unverified(&self) {
let mut guard = self.verified.write().unwrap();
let mut guard = self.verified.write();
if *guard == OwnUserIdentityVerifiedState::Verified {
*guard = OwnUserIdentityVerifiedState::VerificationViolation;
}
@@ -1080,7 +1081,7 @@ impl OwnUserIdentityData {
/// Check if our identity is verified.
pub fn is_verified(&self) -> bool {
*self.verified.read().unwrap() == OwnUserIdentityVerifiedState::Verified
*self.verified.read() == OwnUserIdentityVerifiedState::Verified
}
/// True if we verified our own identity at some point in the past.
@@ -1089,7 +1090,7 @@ impl OwnUserIdentityData {
/// [`OwnUserIdentityData::withdraw_verification()`].
pub fn was_previously_verified(&self) -> bool {
matches!(
*self.verified.read().unwrap(),
*self.verified.read(),
OwnUserIdentityVerifiedState::Verified
| OwnUserIdentityVerifiedState::VerificationViolation
)
@@ -1101,7 +1102,7 @@ impl OwnUserIdentityData {
/// reported to the user. In order to remove this notice users have to
/// verify again or to withdraw the verification requirement.
pub fn withdraw_verification(&self) {
let mut guard = self.verified.write().unwrap();
let mut guard = self.verified.write();
if *guard == OwnUserIdentityVerifiedState::VerificationViolation {
*guard = OwnUserIdentityVerifiedState::NeverVerified;
}
@@ -1117,7 +1118,7 @@ impl OwnUserIdentityData {
/// - Or by withdrawing the verification requirement
/// [`OwnUserIdentity::withdraw_verification`].
pub fn has_verification_violation(&self) -> bool {
*self.verified.read().unwrap() == OwnUserIdentityVerifiedState::VerificationViolation
*self.verified.read() == OwnUserIdentityVerifiedState::VerificationViolation
}
/// Update the identity with a new master key and self signing key.
@@ -1523,7 +1524,7 @@ pub(crate) mod tests {
});
let migrated: OtherUserIdentityData = serde_json::from_value(serialized_value).unwrap();
let pinned_master_key = migrated.pinned_master_key.read().unwrap();
let pinned_master_key = migrated.pinned_master_key.read();
assert_eq!(*pinned_master_key, migrated.master_key().clone());
// Serialize back
@@ -1547,12 +1548,12 @@ pub(crate) mod tests {
// Set `"verified": false`
*json.get_mut("verified").unwrap() = false.into();
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
assert_eq!(*id.verified.read().unwrap(), OwnUserIdentityVerifiedState::NeverVerified);
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::NeverVerified);
// Tweak the json to have `"verified": true`, and repeat
*json.get_mut("verified").unwrap() = true.into();
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
assert_eq!(*id.verified.read().unwrap(), OwnUserIdentityVerifiedState::Verified);
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::Verified);
}
#[test]
@@ -1565,10 +1566,7 @@ pub(crate) mod tests {
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
// Then the value is correctly populated
assert_eq!(
*id.verified.read().unwrap(),
OwnUserIdentityVerifiedState::VerificationViolation
);
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::VerificationViolation);
}
#[test]
@@ -1581,10 +1579,7 @@ pub(crate) mod tests {
let id: OwnUserIdentityData = serde_json::from_value(json.clone()).unwrap();
// Then the old value is re-interpreted as VerificationViolation
assert_eq!(
*id.verified.read().unwrap(),
OwnUserIdentityVerifiedState::VerificationViolation
);
assert_eq!(*id.verified.read(), OwnUserIdentityVerifiedState::VerificationViolation);
}
#[test]

View File

@@ -14,7 +14,7 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::{Arc, RwLock as StdRwLock},
sync::Arc,
time::Duration,
};
@@ -25,6 +25,7 @@ use matrix_sdk_common::{
UnableToDecryptReason, UnsignedDecryptionResult, UnsignedEventLocation, VerificationLevel,
VerificationState,
},
locks::RwLock as StdRwLock,
BoxFuture,
};
use ruma::{

View File

@@ -18,12 +18,12 @@ use std::{
fmt,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, RwLock as StdRwLock,
Arc,
},
time::Duration,
};
use matrix_sdk_common::deserialized_responses::WithheldCode;
use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock as StdRwLock};
use ruma::{
events::{
room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility},
@@ -274,7 +274,7 @@ impl OutboundGroupSession {
request: Arc<ToDeviceRequest>,
share_infos: ShareInfoSet,
) {
self.to_share_with_set.write().unwrap().insert(request_id, (request, share_infos));
self.to_share_with_set.write().insert(request_id, (request, share_infos));
}
/// Create a new `m.room_key.withheld` event content with the given code for
@@ -310,7 +310,7 @@ impl OutboundGroupSession {
) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
let mut no_olm_devices = BTreeMap::new();
let removed = self.to_share_with_set.write().unwrap().remove(request_id);
let removed = self.to_share_with_set.write().remove(request_id);
if let Some((to_device, request)) = removed {
let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
.iter()
@@ -332,10 +332,10 @@ impl OutboundGroupSession {
.collect();
no_olm_devices.insert(user_id.to_owned(), no_olms);
self.shared_with_set.write().unwrap().entry(user_id).or_default().extend(info);
self.shared_with_set.write().entry(user_id).or_default().extend(info);
}
if self.to_share_with_set.read().unwrap().is_empty() {
if self.to_share_with_set.read().is_empty() {
debug!(
session_id = self.session_id(),
room_id = ?self.room_id,
@@ -347,7 +347,7 @@ impl OutboundGroupSession {
}
} else {
let request_ids: Vec<String> =
self.to_share_with_set.read().unwrap().keys().map(|k| k.to_string()).collect();
self.to_share_with_set.read().keys().map(|k| k.to_string()).collect();
error!(
all_request_ids = ?request_ids,
@@ -540,22 +540,21 @@ impl OutboundGroupSession {
/// Has or will the session be shared with the given user/device pair.
pub(crate) fn is_shared_with(&self, device: &DeviceData) -> ShareState {
// Check if we shared the session.
let shared_state =
self.shared_with_set.read().unwrap().get(device.user_id()).and_then(|d| {
d.get(device.device_id()).map(|s| match s {
ShareInfo::Shared(s) => {
if device.curve25519_key() == Some(s.sender_key) {
ShareState::Shared {
message_index: s.message_index,
olm_wedging_index: s.olm_wedging_index,
}
} else {
ShareState::SharedButChangedSenderKey
let shared_state = self.shared_with_set.read().get(device.user_id()).and_then(|d| {
d.get(device.device_id()).map(|s| match s {
ShareInfo::Shared(s) => {
if device.curve25519_key() == Some(s.sender_key) {
ShareState::Shared {
message_index: s.message_index,
olm_wedging_index: s.olm_wedging_index,
}
} else {
ShareState::SharedButChangedSenderKey
}
ShareInfo::Withheld(_) => ShareState::NotShared,
})
});
}
ShareInfo::Withheld(_) => ShareState::NotShared,
})
});
if let Some(state) = shared_state {
state
@@ -565,24 +564,23 @@ impl OutboundGroupSession {
// Find the first request that contains the given user id and
// device ID.
let shared =
self.to_share_with_set.read().unwrap().values().find_map(|(_, share_info)| {
let d = share_info.get(device.user_id())?;
let info = d.get(device.device_id())?;
Some(match info {
ShareInfo::Shared(info) => {
if device.curve25519_key() == Some(info.sender_key) {
ShareState::Shared {
message_index: info.message_index,
olm_wedging_index: info.olm_wedging_index,
}
} else {
ShareState::SharedButChangedSenderKey
let shared = self.to_share_with_set.read().values().find_map(|(_, share_info)| {
let d = share_info.get(device.user_id())?;
let info = d.get(device.device_id())?;
Some(match info {
ShareInfo::Shared(info) => {
if device.curve25519_key() == Some(info.sender_key) {
ShareState::Shared {
message_index: info.message_index,
olm_wedging_index: info.olm_wedging_index,
}
} else {
ShareState::SharedButChangedSenderKey
}
ShareInfo::Withheld(_) => ShareState::NotShared,
})
});
}
ShareInfo::Withheld(_) => ShareState::NotShared,
})
});
shared.unwrap_or(ShareState::NotShared)
}
@@ -591,7 +589,6 @@ impl OutboundGroupSession {
pub(crate) fn is_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
self.shared_with_set
.read()
.unwrap()
.get(device.user_id())
.and_then(|d| {
let info = d.get(device.device_id())?;
@@ -603,7 +600,7 @@ impl OutboundGroupSession {
// Find the first request that contains the given user id and
// device ID.
self.to_share_with_set.read().unwrap().values().any(|(_, share_info)| {
self.to_share_with_set.read().values().any(|(_, share_info)| {
share_info
.get(device.user_id())
.and_then(|d| d.get(device.device_id()))
@@ -622,7 +619,7 @@ impl OutboundGroupSession {
sender_key: Curve25519PublicKey,
index: u32,
) {
self.shared_with_set.write().unwrap().entry(user_id.to_owned()).or_default().insert(
self.shared_with_set.write().entry(user_id.to_owned()).or_default().insert(
device_id.to_owned(),
ShareInfo::new_shared(sender_key, index, Default::default()),
);
@@ -641,7 +638,6 @@ impl OutboundGroupSession {
ShareInfo::new_shared(sender_key, self.message_index().await, Default::default());
self.shared_with_set
.write()
.unwrap()
.entry(user_id.to_owned())
.or_default()
.insert(device_id.to_owned(), share_info);
@@ -650,12 +646,12 @@ impl OutboundGroupSession {
/// Get the list of requests that need to be sent out for this session to be
/// marked as shared.
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
self.to_share_with_set.read().unwrap().values().map(|(req, _)| req.clone()).collect()
self.to_share_with_set.read().values().map(|(req, _)| req.clone()).collect()
}
/// Get the list of request ids this session is waiting for to be sent out.
pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
self.to_share_with_set.read().unwrap().keys().cloned().collect()
self.to_share_with_set.read().keys().cloned().collect()
}
/// Restore a Session from a previously pickled string.
@@ -717,8 +713,8 @@ impl OutboundGroupSession {
message_count: self.message_count.load(Ordering::SeqCst),
shared: self.shared(),
invalidated: self.invalidated(),
shared_with_set: self.shared_with_set.read().unwrap().clone(),
requests: self.to_share_with_set.read().unwrap().clone(),
shared_with_set: self.shared_with_set.read().clone(),
requests: self.to_share_with_set.read().clone(),
}
}
}

View File

@@ -17,12 +17,14 @@ mod share_strategy;
use std::{
collections::{BTreeMap, BTreeSet},
fmt::Debug,
sync::{Arc, RwLock as StdRwLock},
sync::Arc,
};
use futures_util::future::join_all;
use itertools::Itertools;
use matrix_sdk_common::{deserialized_responses::WithheldCode, executor::spawn};
use matrix_sdk_common::{
deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
};
use ruma::{
events::{AnyMessageLikeEventContent, ToDeviceEventType},
serde::Raw,
@@ -60,7 +62,7 @@ impl GroupSessionCache {
}
pub(crate) fn insert(&self, session: OutboundGroupSession) {
self.sessions.write().unwrap().insert(session.room_id().to_owned(), session);
self.sessions.write().insert(session.room_id().to_owned(), session);
}
/// Either get a session for the given room from the cache or load it from
@@ -72,20 +74,20 @@ impl GroupSessionCache {
pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
// Get the cached session, if there isn't one load one from the store
// and put it in the cache.
if let Some(s) = self.sessions.read().unwrap().get(room_id) {
if let Some(s) = self.sessions.read().get(room_id) {
return Some(s.clone());
}
match self.store.get_outbound_group_session(room_id).await {
Ok(Some(s)) => {
{
let mut sessions_being_shared = self.sessions_being_shared.write().unwrap();
let mut sessions_being_shared = self.sessions_being_shared.write();
for request_id in s.pending_request_ids() {
sessions_being_shared.insert(request_id, s.clone());
}
}
self.sessions.write().unwrap().insert(room_id.to_owned(), s.clone());
self.sessions.write().insert(room_id.to_owned(), s.clone());
Some(s)
}
@@ -104,20 +106,20 @@ impl GroupSessionCache {
/// * `room_id` - The id of the room for which we should get the outbound
/// group session.
fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
self.sessions.read().unwrap().get(room_id).cloned()
self.sessions.read().get(room_id).cloned()
}
/// Returns whether any session is withheld with the given device and code.
fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
self.sessions.read().unwrap().values().any(|s| s.is_withheld_to(device, code))
self.sessions.read().values().any(|s| s.is_withheld_to(device, code))
}
fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
self.sessions_being_shared.write().unwrap().remove(id)
self.sessions_being_shared.write().remove(id)
}
fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
self.sessions_being_shared.write().unwrap().insert(id, session);
self.sessions_being_shared.write().insert(id, session);
}
}

View File

@@ -125,7 +125,7 @@ pub(crate) async fn collect_session_recipients(
trace!(?users, ?settings, "Calculating group session recipients");
let users_shared_with: BTreeSet<OwnedUserId> =
outbound.shared_with_set.read().unwrap().keys().cloned().collect();
outbound.shared_with_set.read().keys().cloned().collect();
let users_shared_with: BTreeSet<&UserId> = users_shared_with.iter().map(Deref::deref).collect();
@@ -339,7 +339,7 @@ fn is_session_overshared_for_user(
let recipient_device_ids: BTreeSet<&DeviceId> =
recipient_devices.iter().map(|d| d.device_id()).collect();
let guard = outbound_session.shared_with_set.read().unwrap();
let guard = outbound_session.shared_with_set.read();
let Some(shared) = guard.get(user_id) else {
return false;

View File

@@ -14,11 +14,11 @@
use std::{
collections::{BTreeMap, BTreeSet},
sync::{Arc, RwLock as StdRwLock},
sync::Arc,
time::Duration,
};
use matrix_sdk_common::failures_cache::FailuresCache;
use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
use ruma::{
api::client::keys::claim_keys::v3::{
Request as KeysClaimRequest, Response as KeysClaimResponse,
@@ -98,7 +98,7 @@ impl SessionManager {
/// Mark the outgoing request as sent.
pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
self.outgoing_to_device_requests.write().unwrap().remove(id);
self.outgoing_to_device_requests.write().remove(id);
}
pub async fn mark_device_as_wedged(
@@ -121,13 +121,11 @@ impl SessionManager {
if should_unwedge {
self.users_for_key_claim
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
self.wedged_devices
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
@@ -142,7 +140,6 @@ impl SessionManager {
pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
self.wedged_devices
.read()
.unwrap()
.get(device.user_id())
.is_some_and(|d| d.contains(device.device_id()))
}
@@ -151,13 +148,7 @@ impl SessionManager {
///
/// If the device was wedged this will queue up a dummy to-device message.
async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
if self
.wedged_devices
.write()
.unwrap()
.get_mut(user_id)
.is_some_and(|d| d.remove(device_id))
{
if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id)) {
if let Some(device) = self.store.get_device(user_id, device_id).await? {
let (_, content) =
device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;
@@ -176,7 +167,6 @@ impl SessionManager {
self.outgoing_to_device_requests
.write()
.unwrap()
.insert(request.request_id.clone(), request);
}
}
@@ -278,7 +268,7 @@ impl SessionManager {
// Add the list of sessions that for some reason automatically need to
// create an Olm session.
for (user, device_ids) in self.users_for_key_claim.read().unwrap().iter() {
for (user, device_ids) in self.users_for_key_claim.read().iter() {
missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
device_ids
.iter()
@@ -319,12 +309,12 @@ impl SessionManager {
// stash the details of the request so that we can refer to it when handling the
// response
*(self.current_key_claim_request.write().unwrap()) = result.clone();
*(self.current_key_claim_request.write()) = result.clone();
Ok(result)
}
fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
self.failed_devices.read().unwrap().get(user_id).is_some_and(|d| d.contains(device_id))
self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
}
/// This method will try to figure out for which devices a one-time key was
@@ -354,7 +344,7 @@ impl SessionManager {
) {
// First check that the response is for the request we were expecting.
let request = {
let mut guard = self.current_key_claim_request.write().unwrap();
let mut guard = self.current_key_claim_request.write();
let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());
if Some(request_id) == expected_request_id {
@@ -416,7 +406,7 @@ impl SessionManager {
"Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
);
let mut failed_devices_lock = self.failed_devices.write().unwrap();
let mut failed_devices_lock = self.failed_devices.write();
for (user_id, device_set) in missing_devices_by_user {
failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
@@ -542,7 +532,6 @@ impl SessionManager {
self.failed_devices
.write()
.unwrap()
.entry(user_id.to_owned())
.or_default()
.insert(device_id.to_owned());
@@ -573,7 +562,7 @@ impl SessionManager {
info!(sessions = ?new_sessions, "Established new Olm sessions");
for (user, device_map) in new_sessions {
if let Some(user_cache) = self.failed_devices.read().unwrap().get(user) {
if let Some(user_cache) = self.failed_devices.read().get(user) {
user_cache.remove(device_map.into_keys());
}
}
@@ -597,14 +586,9 @@ impl SessionManager {
#[cfg(test)]
mod tests {
use std::{
collections::BTreeMap,
iter,
ops::Deref,
sync::{Arc, RwLock as StdRwLock},
time::Duration,
};
use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};
use matrix_sdk_common::locks::RwLock as StdRwLock;
use matrix_sdk_test::{async_test, ruma_response_from_json};
use ruma::{
api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
@@ -861,11 +845,11 @@ mod tests {
let curve_key = bob_device.curve25519_key().unwrap();
assert!(!manager.users_for_key_claim.read().unwrap().contains_key(bob.user_id()));
assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
assert!(!manager.is_device_wedged(&bob_device));
manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
assert!(manager.is_device_wedged(&bob_device));
assert!(manager.users_for_key_claim.read().unwrap().contains_key(bob.user_id()));
assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));
let (txn_id, request) =
manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();
@@ -885,13 +869,13 @@ mod tests {
let response = KeyClaimResponse::new(one_time_keys);
assert!(manager.outgoing_to_device_requests.read().unwrap().is_empty());
assert!(manager.outgoing_to_device_requests.read().is_empty());
manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
assert!(!manager.is_device_wedged(&bob_device));
assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
assert!(!manager.outgoing_to_device_requests.read().unwrap().is_empty())
assert!(!manager.outgoing_to_device_requests.read().is_empty())
}
#[async_test]
@@ -1012,7 +996,6 @@ mod tests {
manager
.failed_devices
.write()
.unwrap()
.get(alice)
.unwrap()
.expire(&alice_account.device_id().to_owned());
@@ -1027,7 +1010,6 @@ mod tests {
assert!(manager
.failed_devices
.read()
.unwrap()
.get(alice)
.unwrap()
.failure_count(alice_account.device_id())

View File

@@ -22,10 +22,11 @@ use std::{
fmt::Display,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock as StdRwLock, Weak,
Arc, Weak,
},
};
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
@@ -104,7 +105,6 @@ impl GroupSessionStore {
pub fn add(&self, session: InboundGroupSession) -> bool {
self.entries
.write()
.unwrap()
.entry(session.room_id().to_owned())
.or_default()
.insert(session.session_id().to_owned(), session)
@@ -113,12 +113,12 @@ impl GroupSessionStore {
/// Get all the group sessions the store knows about.
pub fn get_all(&self) -> Vec<InboundGroupSession> {
self.entries.read().unwrap().values().flat_map(HashMap::values).cloned().collect()
self.entries.read().values().flat_map(HashMap::values).cloned().collect()
}
/// Get the number of `InboundGroupSession`s we have.
pub fn count(&self) -> usize {
self.entries.read().unwrap().values().map(HashMap::len).sum()
self.entries.read().values().map(HashMap::len).sum()
}
/// Get a inbound group session from our store.
@@ -128,7 +128,7 @@ impl GroupSessionStore {
///
/// * `session_id` - The unique id of the session.
pub fn get(&self, room_id: &RoomId, session_id: &str) -> Option<InboundGroupSession> {
self.entries.read().unwrap().get(room_id)?.get(session_id).cloned()
self.entries.read().get(room_id)?.get(session_id).cloned()
}
}
@@ -151,7 +151,6 @@ impl DeviceStore {
let user_id = device.user_id();
self.entries
.write()
.unwrap()
.entry(user_id.to_owned())
.or_default()
.insert(device.device_id().into(), device)
@@ -160,7 +159,7 @@ impl DeviceStore {
/// Get the device with the given device_id and belonging to the given user.
pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
Some(self.entries.read().unwrap().get(user_id)?.get(device_id)?.clone())
Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
}
/// Remove the device with the given device_id and belonging to the given
@@ -168,14 +167,13 @@ impl DeviceStore {
///
/// Returns the device if it was removed, None if it wasn't in the store.
pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
self.entries.write().unwrap().get_mut(user_id)?.remove(device_id)
self.entries.write().get_mut(user_id)?.remove(device_id)
}
/// Get a read-only view over all devices of the given user.
pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
self.entries
.write()
.unwrap()
.entry(user_id.to_owned())
.or_default()
.iter()

View File

@@ -15,11 +15,12 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
convert::Infallible,
sync::RwLock as StdRwLock,
};
use async_trait::async_trait;
use matrix_sdk_common::store_locks::memory_store_helper::try_take_leased_lock;
use matrix_sdk_common::{
locks::RwLock as StdRwLock, store_locks::memory_store_helper::try_take_leased_lock,
};
use ruma::{
events::secret::request::SecretName, time::Instant, DeviceId, OwnedDeviceId, OwnedRoomId,
OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
@@ -143,7 +144,7 @@ impl MemoryStore {
}
fn save_sessions(&self, sessions: Vec<Session>) {
let mut session_store = self.sessions.write().unwrap();
let mut session_store = self.sessions.write();
for session in sessions {
let entry = session_store.entry(session.sender_key().to_base64()).or_default();
@@ -159,12 +160,11 @@ impl MemoryStore {
fn save_outbound_group_sessions(&self, sessions: Vec<OutboundGroupSession>) {
self.outbound_group_sessions
.write()
.unwrap()
.extend(sessions.into_iter().map(|s| (s.room_id().to_owned(), s)));
}
fn save_private_identity(&self, private_identity: Option<PrivateCrossSigningIdentity>) {
*self.private_identity.write().unwrap() = private_identity;
*self.private_identity.write() = private_identity;
}
/// Return all the [`InboundGroupSession`]s we have, paired with the
@@ -176,7 +176,6 @@ impl MemoryStore {
let lookup = |s: &InboundGroupSession| {
self.inbound_group_sessions_backed_up_to
.read()
.unwrap()
.get(&s.room_id)?
.get(s.session_id())
.cloned()
@@ -202,11 +201,11 @@ impl CryptoStore for MemoryStore {
type Error = Infallible;
async fn load_account(&self) -> Result<Option<Account>> {
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone()))
Ok(self.account.read().as_ref().map(|acc| acc.deep_clone()))
}
async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
Ok(self.private_identity.read().unwrap().clone())
Ok(self.private_identity.read().clone())
}
async fn next_batch_token(&self) -> Result<Option<String>> {
@@ -215,7 +214,7 @@ impl CryptoStore for MemoryStore {
async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> {
if let Some(account) = changes.account {
*self.account.write().unwrap() = Some(account);
*self.account.write() = Some(account);
}
Ok(())
@@ -232,7 +231,7 @@ impl CryptoStore for MemoryStore {
self.delete_devices(changes.devices.deleted);
{
let mut identities = self.identities.write().unwrap();
let mut identities = self.identities.write();
for identity in changes.identities.new.into_iter().chain(changes.identities.changed) {
identities.insert(
identity.user_id().to_owned(),
@@ -242,15 +241,15 @@ impl CryptoStore for MemoryStore {
}
{
let mut olm_hashes = self.olm_hashes.write().unwrap();
let mut olm_hashes = self.olm_hashes.write();
for hash in changes.message_hashes {
olm_hashes.entry(hash.sender_key.to_owned()).or_default().insert(hash.hash.clone());
}
}
{
let mut outgoing_key_requests = self.outgoing_key_requests.write().unwrap();
let mut key_requests_by_info = self.key_requests_by_info.write().unwrap();
let mut outgoing_key_requests = self.outgoing_key_requests.write();
let mut key_requests_by_info = self.key_requests_by_info.write();
for key_request in changes.key_requests {
let id = key_request.request_id.clone();
@@ -275,14 +274,14 @@ impl CryptoStore for MemoryStore {
}
{
let mut secret_inbox = self.secret_inbox.write().unwrap();
let mut secret_inbox = self.secret_inbox.write();
for secret in changes.secrets {
secret_inbox.entry(secret.secret_name.to_string()).or_default().push(secret);
}
}
{
let mut direct_withheld_info = self.direct_withheld_info.write().unwrap();
let mut direct_withheld_info = self.direct_withheld_info.write();
for (room_id, data) in changes.withheld_session_info {
for (session_id, event) in data {
direct_withheld_info
@@ -298,7 +297,7 @@ impl CryptoStore for MemoryStore {
}
if !changes.room_settings.is_empty() {
let mut settings = self.room_settings.write().unwrap();
let mut settings = self.room_settings.write();
settings.extend(changes.room_settings);
}
@@ -311,7 +310,7 @@ impl CryptoStore for MemoryStore {
backed_up_to_version: Option<&str>,
) -> Result<()> {
let mut inbound_group_sessions_backed_up_to =
self.inbound_group_sessions_backed_up_to.write().unwrap();
self.inbound_group_sessions_backed_up_to.write();
for session in sessions {
let room_id = session.room_id();
@@ -339,7 +338,7 @@ impl CryptoStore for MemoryStore {
}
async fn get_sessions(&self, sender_key: &str) -> Result<Option<Vec<Session>>> {
Ok(self.sessions.read().unwrap().get(sender_key).cloned())
Ok(self.sessions.read().get(sender_key).cloned())
}
async fn get_inbound_group_session(
@@ -358,7 +357,6 @@ impl CryptoStore for MemoryStore {
Ok(self
.direct_withheld_info
.read()
.unwrap()
.get(room_id)
.and_then(|e| Some(e.get(session_id)?.to_owned())))
}
@@ -458,7 +456,7 @@ impl CryptoStore for MemoryStore {
room_and_session_ids: &[(&RoomId, &str)],
) -> Result<()> {
let mut inbound_group_sessions_backed_up_to =
self.inbound_group_sessions_backed_up_to.write().unwrap();
self.inbound_group_sessions_backed_up_to.write();
for &(room_id, session_id) in room_and_session_ids {
let session = self.inbound_group_sessions.get(room_id, session_id);
@@ -506,15 +504,15 @@ impl CryptoStore for MemoryStore {
&self,
room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> {
Ok(self.outbound_group_sessions.read().unwrap().get(room_id).cloned())
Ok(self.outbound_group_sessions.read().get(room_id).cloned())
}
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
Ok(self.tracked_users.read().unwrap().values().cloned().collect())
Ok(self.tracked_users.read().values().cloned().collect())
}
async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
self.tracked_users.write().unwrap().extend(tracked_users.iter().map(|(user_id, dirty)| {
self.tracked_users.write().extend(tracked_users.iter().map(|(user_id, dirty)| {
let user_id: OwnedUserId = user_id.to_owned().into();
(user_id.clone(), TrackedUser { user_id, dirty: *dirty })
}));
@@ -543,7 +541,7 @@ impl CryptoStore for MemoryStore {
}
async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<UserIdentityData>> {
let serialized = self.identities.read().unwrap().get(user_id).cloned();
let serialized = self.identities.read().get(user_id).cloned();
match serialized {
None => Ok(None),
Some(serialized) => {
@@ -557,7 +555,6 @@ impl CryptoStore for MemoryStore {
Ok(self
.olm_hashes
.write()
.unwrap()
.entry(message_hash.sender_key.to_owned())
.or_default()
.contains(&message_hash.hash))
@@ -567,7 +564,7 @@ impl CryptoStore for MemoryStore {
&self,
request_id: &TransactionId,
) -> Result<Option<GossipRequest>> {
Ok(self.outgoing_key_requests.read().unwrap().get(request_id).cloned())
Ok(self.outgoing_key_requests.read().get(request_id).cloned())
}
async fn get_secret_request_by_info(
@@ -579,16 +576,14 @@ impl CryptoStore for MemoryStore {
Ok(self
.key_requests_by_info
.read()
.unwrap()
.get(&key_info_string)
.and_then(|i| self.outgoing_key_requests.read().unwrap().get(i).cloned()))
.and_then(|i| self.outgoing_key_requests.read().get(i).cloned()))
}
async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
Ok(self
.outgoing_key_requests
.read()
.unwrap()
.values()
.filter(|req| !req.sent_out)
.cloned()
@@ -596,10 +591,10 @@ impl CryptoStore for MemoryStore {
}
async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
let req = self.outgoing_key_requests.write().unwrap().remove(request_id);
let req = self.outgoing_key_requests.write().remove(request_id);
if let Some(i) = req {
let key_info_string = encode_key_info(&i.info);
self.key_requests_by_info.write().unwrap().remove(&key_info_string);
self.key_requests_by_info.write().remove(&key_info_string);
}
Ok(())
@@ -609,36 +604,30 @@ impl CryptoStore for MemoryStore {
&self,
secret_name: &SecretName,
) -> Result<Vec<GossippedSecret>> {
Ok(self
.secret_inbox
.write()
.unwrap()
.entry(secret_name.to_string())
.or_default()
.to_owned())
Ok(self.secret_inbox.write().entry(secret_name.to_string()).or_default().to_owned())
}
async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> {
self.secret_inbox.write().unwrap().remove(secret_name.as_str());
self.secret_inbox.write().remove(secret_name.as_str());
Ok(())
}
async fn get_room_settings(&self, room_id: &RoomId) -> Result<Option<RoomSettings>> {
Ok(self.room_settings.read().unwrap().get(room_id).cloned())
Ok(self.room_settings.read().get(room_id).cloned())
}
async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
Ok(self.custom_values.read().unwrap().get(key).cloned())
Ok(self.custom_values.read().get(key).cloned())
}
async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
self.custom_values.write().unwrap().insert(key.to_owned(), value);
self.custom_values.write().insert(key.to_owned(), value);
Ok(())
}
async fn remove_custom_value(&self, key: &str) -> Result<()> {
self.custom_values.write().unwrap().remove(key);
self.custom_values.write().remove(key);
Ok(())
}
@@ -648,7 +637,7 @@ impl CryptoStore for MemoryStore {
key: &str,
holder: &str,
) -> Result<bool> {
Ok(try_take_leased_lock(&mut self.leases.write().unwrap(), lease_duration_ms, key, holder))
Ok(try_take_leased_lock(&mut self.leases.write(), lease_duration_ms, key, holder))
}
}
@@ -1167,7 +1156,7 @@ mod integration_tests {
impl MemoryStore {
fn get_static_account(&self) -> Option<StaticAccountData> {
self.account.read().unwrap().as_ref().map(|acc| acc.static_data().clone())
self.account.read().as_ref().map(|acc| acc.static_data().clone())
}
}

View File

@@ -43,13 +43,14 @@ use std::{
fmt::Debug,
ops::Deref,
pin::pin,
sync::{atomic::Ordering, Arc, RwLock as StdRwLock},
sync::{atomic::Ordering, Arc},
time::Duration,
};
use as_variant::as_variant;
use futures_core::Stream;
use futures_util::StreamExt;
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{
encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
OwnedRoomId, OwnedUserId, UserId,
@@ -155,7 +156,7 @@ impl KeyQueryManager {
let tracked_users = cache.store.load_tracked_users().await?;
let mut query_users_lock = self.users_for_key_query.lock().await;
let mut tracked_users_cache = cache.tracked_users.write().unwrap();
let mut tracked_users_cache = cache.tracked_users.write();
for user in tracked_users {
tracked_users_cache.insert(user.user_id.to_owned());
@@ -245,7 +246,7 @@ impl SyncedKeyQueryManager<'_> {
let mut key_query_lock = self.manager.users_for_key_query.lock().await;
{
let mut tracked_users = self.cache.tracked_users.write().unwrap();
let mut tracked_users = self.cache.tracked_users.write();
for user_id in users {
if tracked_users.insert(user_id.to_owned()) {
key_query_lock.insert_user(user_id);
@@ -271,7 +272,7 @@ impl SyncedKeyQueryManager<'_> {
let mut key_query_lock = self.manager.users_for_key_query.lock().await;
{
let tracked_users = &self.cache.tracked_users.read().unwrap();
let tracked_users = &self.cache.tracked_users.read();
for user_id in users {
if tracked_users.contains(user_id) {
key_query_lock.insert_user(user_id);
@@ -297,7 +298,7 @@ impl SyncedKeyQueryManager<'_> {
let mut key_query_lock = self.manager.users_for_key_query.lock().await;
{
let tracked_users = self.cache.tracked_users.read().unwrap();
let tracked_users = self.cache.tracked_users.read();
for user_id in users {
if tracked_users.contains(user_id) {
let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
@@ -330,7 +331,7 @@ impl SyncedKeyQueryManager<'_> {
/// See the docs for [`crate::OlmMachine::tracked_users()`].
pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
self.cache.tracked_users.read().unwrap().iter().cloned().collect()
self.cache.tracked_users.read().iter().cloned().collect()
}
/// Mark the given user as being tracked for device lists, and mark that it
@@ -340,7 +341,7 @@ impl SyncedKeyQueryManager<'_> {
/// next time [`Store::users_for_key_query()`] is called.
pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
self.manager.users_for_key_query.lock().await.insert_user(user);
self.cache.tracked_users.write().unwrap().insert(user.to_owned());
self.cache.tracked_users.write().insert(user.to_owned());
self.cache.store.save_tracked_users(&[(user, true)]).await
}

View File

@@ -12,12 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::BTreeMap,
sync::{Arc, RwLock as StdRwLock},
};
use std::{collections::BTreeMap, sync::Arc};
use as_variant::as_variant;
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId};
#[cfg(feature = "qrcode")]
use tracing::debug;
@@ -69,7 +67,7 @@ impl VerificationCache {
#[cfg(test)]
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.inner.verification.read().unwrap().values().all(|m| m.is_empty())
self.inner.verification.read().values().all(|m| m.is_empty())
}
/// Add a new `Verification` object to the cache, this will cancel any
@@ -78,7 +76,7 @@ impl VerificationCache {
pub fn insert(&self, verification: impl Into<Verification>) {
let verification = verification.into();
let mut verification_write_guard = self.inner.verification.write().unwrap();
let mut verification_write_guard = self.inner.verification.write();
let user_verifications =
verification_write_guard.entry(verification.other_user().to_owned()).or_default();
@@ -150,22 +148,21 @@ impl VerificationCache {
self.inner
.verification
.write()
.unwrap()
.entry(verification.other_user().to_owned())
.or_default()
.insert(verification.flow_id().to_owned(), verification.clone());
}
pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
self.inner.verification.read().unwrap().get(sender)?.get(flow_id).cloned()
self.inner.verification.read().get(sender)?.get(flow_id).cloned()
}
pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
self.inner.outgoing_requests.read().unwrap().values().cloned().collect()
self.inner.outgoing_requests.read().values().cloned().collect()
}
pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
let verification = &mut self.inner.verification.write().unwrap();
let verification = &mut self.inner.verification.write();
for user_verification in verification.values_mut() {
user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
@@ -186,7 +183,7 @@ impl VerificationCache {
pub fn add_request(&self, request: OutgoingRequest) {
trace!("Adding an outgoing request {:?}", request);
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
}
pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
@@ -211,7 +208,7 @@ impl VerificationCache {
"Storing the request info, waiting for the request to be marked as sent"
);
self.inner.flow_ids_waiting_for_response.write().unwrap().insert(
self.inner.flow_ids_waiting_for_response.write().insert(
request_info.request_id.to_owned(),
(recipient.to_owned(), request_info.flow_id),
);
@@ -235,7 +232,7 @@ impl VerificationCache {
request: Arc::new(request.into()),
};
self.inner.outgoing_requests.write().unwrap().insert(request_id, request);
self.inner.outgoing_requests.write().insert(request_id, request);
}
OutgoingContent::Room(r, c) => {
@@ -247,18 +244,18 @@ impl VerificationCache {
request_id: request_id.clone(),
};
self.inner.outgoing_requests.write().unwrap().insert(request_id, request);
self.inner.outgoing_requests.write().insert(request_id, request);
}
}
}
pub fn mark_request_as_sent(&self, request_id: &TransactionId) {
if let Some(request_id) = self.inner.outgoing_requests.write().unwrap().remove(request_id) {
if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) {
trace!(?request_id, "Marking a verification HTTP request as sent");
}
if let Some((user_id, flow_id)) =
self.inner.flow_ids_waiting_for_response.read().unwrap().get(request_id)
self.inner.flow_ids_waiting_for_response.read().get(request_id)
{
if let Some(verification) = self.get(user_id, flow_id.as_str()) {
match verification {

View File

@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap,
sync::{Arc, RwLock as StdRwLock},
};
use std::{collections::HashMap, sync::Arc};
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{
events::{
key::verification::VerificationMethod, AnyToDeviceEvent, AnyToDeviceEventContent,
@@ -153,13 +151,12 @@ impl VerificationMachine {
user_id: &UserId,
flow_id: impl AsRef<str>,
) -> Option<VerificationRequest> {
self.requests.read().unwrap().get(user_id)?.get(flow_id.as_ref()).cloned()
self.requests.read().get(user_id)?.get(flow_id.as_ref()).cloned()
}
pub fn get_requests(&self, user_id: &UserId) -> Vec<VerificationRequest> {
self.requests
.read()
.unwrap()
.get(user_id)
.map(|v| v.iter().map(|(_, value)| value.clone()).collect())
.unwrap_or_default()
@@ -174,7 +171,7 @@ impl VerificationMachine {
return;
}
let mut requests = self.requests.write().unwrap();
let mut requests = self.requests.write();
let user_requests = requests.entry(request.other_user().to_owned()).or_default();
// Cancel all the old verifications requests as well as the new one we
@@ -247,7 +244,7 @@ impl VerificationMachine {
let mut events = vec![];
let mut requests: Vec<OutgoingVerificationRequest> = {
let mut requests = self.requests.write().unwrap();
let mut requests = self.requests.write();
for user_verification in requests.values_mut() {
user_verification.retain(|_, v| !(v.is_done() || v.is_cancelled()));

View File

@@ -12,12 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
matches,
sync::{Arc, Mutex},
time::Duration,
};
use std::{matches, sync::Arc, time::Duration};
use matrix_sdk_common::locks::Mutex;
use ruma::{
events::{
key::verification::{
@@ -356,7 +353,7 @@ impl<S: Clone> SasState<S> {
let their_public_key = Curve25519PublicKey::from_slice(content.public_key().as_bytes())
.map_err(|_| CancelCode::from("Invalid public key"))?;
if let Some(sas) = self.inner.lock().unwrap().take() {
if let Some(sas) = self.inner.lock().take() {
sas.diffie_hellman(their_public_key).map_err(|_| "Invalid public key".into())
} else {
Err(CancelCode::UnexpectedMessage)
@@ -1127,7 +1124,7 @@ impl SasState<KeysExchanged> {
/// second element the English description of the emoji.
pub fn get_emoji(&self) -> [Emoji; 7] {
get_emoji(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
self.state.we_started,
@@ -1140,7 +1137,7 @@ impl SasState<KeysExchanged> {
/// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
self.state.we_started,
@@ -1153,7 +1150,7 @@ impl SasState<KeysExchanged> {
/// the short auth string.
pub fn get_decimal(&self) -> (u16, u16, u16) {
get_decimal(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
self.state.we_started,
@@ -1175,7 +1172,7 @@ impl SasState<KeysExchanged> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let (devices, master_keys) = receive_mac_event(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
sender,
@@ -1239,7 +1236,7 @@ impl SasState<Confirmed> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let (devices, master_keys) = receive_mac_event(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
sender,
@@ -1283,7 +1280,7 @@ impl SasState<Confirmed> {
self.check_event(sender, content.flow_id()).map_err(|c| self.clone().cancel(true, c))?;
let (devices, master_keys) = receive_mac_event(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
sender,
@@ -1315,7 +1312,7 @@ impl SasState<Confirmed> {
/// The content needs to be automatically sent to the other side.
pub fn as_content(&self) -> OutgoingContent {
get_mac_content(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
&self.verification_flow_id,
self.state.accepted_protocols.message_auth_code,
@@ -1376,7 +1373,7 @@ impl SasState<MacReceived> {
/// second element the English description of the emoji.
pub fn get_emoji(&self) -> [Emoji; 7] {
get_emoji(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
self.state.we_started,
@@ -1389,7 +1386,7 @@ impl SasState<MacReceived> {
/// numbers can be converted to a unique emoji defined by the spec.
pub fn get_emoji_index(&self) -> [u8; 7] {
get_emoji_index(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
self.state.we_started,
@@ -1402,7 +1399,7 @@ impl SasState<MacReceived> {
/// the short auth string.
pub fn get_decimal(&self) -> (u16, u16, u16) {
get_decimal(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
self.verification_flow_id.as_str(),
self.state.we_started,
@@ -1417,7 +1414,7 @@ impl SasState<WaitingForDone> {
/// wasn't already sent.
pub fn as_content(&self) -> OutgoingContent {
get_mac_content(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
&self.verification_flow_id,
self.state.accepted_protocols.message_auth_code,
@@ -1480,7 +1477,7 @@ impl SasState<Done> {
/// wasn't already sent.
pub fn as_content(&self) -> OutgoingContent {
get_mac_content(
&self.state.sas.lock().unwrap(),
&self.state.sas.lock(),
&self.ids,
&self.verification_flow_id,
self.state.accepted_protocols.message_auth_code,