From d38995b3f912a1138a7ea4655abab98c3a215c62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Apr 2022 14:07:03 +0200 Subject: [PATCH 01/10] fix(crypto): Add our own device to the store when we create a new account --- benchmarks/benches/crypto_bench.rs | 8 +-- crates/matrix-sdk-crypto/README.md | 4 +- crates/matrix-sdk-crypto/src/backups/mod.rs | 2 +- .../src/file_encryption/key_export.rs | 4 +- .../src/identities/device.rs | 24 ++++++--- crates/matrix-sdk-crypto/src/machine.rs | 50 ++++++++++--------- .../src/session_manager/group_sessions.rs | 2 +- 7 files changed, 52 insertions(+), 42 deletions(-) diff --git a/benchmarks/benches/crypto_bench.rs b/benchmarks/benches/crypto_bench.rs index 4d0311f49..59017d64b 100644 --- a/benchmarks/benches/crypto_bench.rs +++ b/benchmarks/benches/crypto_bench.rs @@ -51,7 +51,7 @@ fn huge_keys_query_response() -> get_keys::v3::Response { pub fn keys_query(c: &mut Criterion) { let runtime = Builder::new_multi_thread().build().expect("Can't create runtime"); - let machine = OlmMachine::new(alice_id(), alice_device_id()); + let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id())); let response = keys_query_response(); let txn_id = TransactionId::new(); @@ -101,7 +101,7 @@ pub fn keys_claiming(c: &mut Criterion) { group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| { b.iter_batched( || { - let machine = OlmMachine::new(alice_id(), alice_device_id()); + let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id())); runtime .block_on(machine.mark_request_as_sent(&txn_id, &keys_query_response)) .unwrap(); @@ -151,7 +151,7 @@ pub fn room_key_sharing(c: &mut Criterion) { let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len()); - let machine = OlmMachine::new(alice_id(), alice_device_id()); + let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id())); runtime.block_on(machine.mark_request_as_sent(&txn_id, &keys_query_response)).unwrap(); runtime.block_on(machine.mark_request_as_sent(&txn_id, &response)).unwrap(); @@ -214,7 +214,7 @@ pub fn room_key_sharing(c: &mut Criterion) { pub fn devices_missing_sessions_collecting(c: &mut Criterion) { let runtime = Builder::new_multi_thread().build().expect("Can't create runtime"); - let machine = OlmMachine::new(alice_id(), alice_device_id()); + let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id())); let response = huge_keys_query_response(); let txn_id = TransactionId::new(); let users: Vec = response.device_keys.keys().cloned().collect(); diff --git a/crates/matrix-sdk-crypto/README.md b/crates/matrix-sdk-crypto/README.md index e5b2e2690..9891105ba 100644 --- a/crates/matrix-sdk-crypto/README.md +++ b/crates/matrix-sdk-crypto/README.md @@ -28,7 +28,7 @@ use ruma::{ #[tokio::main] async fn main() -> Result<(), OlmError> { let alice = user_id!("@alice:example.org"); - let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); + let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; let to_device_events = ToDevice::default(); let changed_devices = DeviceLists::default(); @@ -69,4 +69,4 @@ The following crate feature flags are available: * `qrcode`: Enbles QRcode generation and reading code -* `testing`: provides facilities and functions for tests, in particular for integration testing store implementations. ATTENTION: do not ever use outside of tests, we do not provide any stability warantees on these, these are merely helpers. If you find you _need_ any function provided here outside of tests, please open a Github Issue and inform us about your use case for us to consider. \ No newline at end of file +* `testing`: provides facilities and functions for tests, in particular for integration testing store implementations. ATTENTION: do not ever use outside of tests, we do not provide any stability warantees on these, these are merely helpers. If you find you _need_ any function provided here outside of tests, please open a Github Issue and inform us about your use case for us to consider. diff --git a/crates/matrix-sdk-crypto/src/backups/mod.rs b/crates/matrix-sdk-crypto/src/backups/mod.rs index bf83fb7ae..15494822a 100644 --- a/crates/matrix-sdk-crypto/src/backups/mod.rs +++ b/crates/matrix-sdk-crypto/src/backups/mod.rs @@ -459,7 +459,7 @@ mod tests { #[async_test] async fn memory_store_backups() -> Result<(), OlmError> { - let machine = OlmMachine::new(alice_id(), alice_device_id()); + let machine = OlmMachine::new(alice_id(), alice_device_id()).await; backup_flow(machine).await } diff --git a/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs b/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs index c3bbb3e06..42cdb63f1 100644 --- a/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs +++ b/crates/matrix-sdk-crypto/src/file_encryption/key_export.rs @@ -80,8 +80,8 @@ pub enum KeyExportError { /// # use ruma::{device_id, user_id}; /// # use futures::executor::block_on; /// # let alice = user_id!("@alice:example.org"); -/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); /// # block_on(async { +/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; /// # let export = Cursor::new("".to_owned()); /// let exported_keys = decrypt_key_export(export, "1234").unwrap(); /// machine.import_keys(exported_keys, false, |_, _| {}).await.unwrap(); @@ -137,8 +137,8 @@ pub fn decrypt_key_export( /// # use ruma::{device_id, user_id, room_id}; /// # use futures::executor::block_on; /// # let alice = user_id!("@alice:example.org"); -/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); /// # block_on(async { +/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; /// let room_id = room_id!("!test:localhost"); /// let exported_keys = machine.export_keys(|s| s.room_id() == room_id).await.unwrap(); /// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1); diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 3105fc681..3103b0575 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -40,6 +40,8 @@ use tracing::warn; use vodozemac::{Curve25519PublicKey, Ed25519PublicKey}; use super::{atomic_bool_deserializer, atomic_bool_serializer}; +#[cfg(any(test, feature = "testing"))] +use crate::OlmMachine; use crate::{ error::{EventError, OlmError, OlmResult, SignatureError}, identities::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities}, @@ -47,10 +49,8 @@ use crate::{ store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult}, types::{DeviceKey, DeviceKeys, SignedKey}, verification::VerificationMachine, - OutgoingVerificationRequest, Sas, ToDeviceRequest, VerificationRequest, + OutgoingVerificationRequest, ReadOnlyAccount, Sas, ToDeviceRequest, VerificationRequest, }; -#[cfg(any(test, feature = "testing"))] -use crate::{OlmMachine, ReadOnlyAccount}; /// A read-only version of a `Device`. #[derive(Clone, Serialize, Deserialize)] @@ -601,14 +601,22 @@ impl ReadOnlyDevice { ReadOnlyDevice::from_account(machine.account()).await } - #[cfg(any(test, feature = "testing"))] - #[allow(dead_code)] - /// Generate the Device from the reference of an Account. + /// Create a `ReadOnlyDevice` from an `Account` /// - /// TESTING FACILITY ONLY, DO NOT USE OUTSIDE OF TESTS + /// We will have our own device in the store once we receive a keys/query + /// response, but this is useful to create it before we receive such a + /// response. + /// + /// It also makes it easier to check that the server doesn't lie about our + /// own device. + /// + /// *Don't* use this after we received a keys/query response, other + /// users/devices might add signatures to our own device, which can't be + /// replicated locally. pub async fn from_account(account: &ReadOnlyAccount) -> ReadOnlyDevice { let device_keys = account.device_keys().await; - ReadOnlyDevice::try_from(&device_keys).unwrap() + ReadOnlyDevice::try_from(&device_keys) + .expect("Creating a device from our own account should always succeed") } } diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index fcf984fed..a97728b8e 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -67,7 +67,7 @@ use crate::{ SecretImportError, Store, }, verification::{Verification, VerificationMachine, VerificationRequest}, - CrossSigningKeyExport, RoomKeyImportResult, ToDeviceRequest, + CrossSigningKeyExport, ReadOnlyDevice, RoomKeyImportResult, ToDeviceRequest, }; /// State machine implementation of the Olm/Megolm encryption protocol used for @@ -128,17 +128,12 @@ impl OlmMachine { /// * `user_id` - The unique id of the user that owns this machine. /// /// * `device_id` - The unique id of the device that owns this machine. - pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self { + pub async fn new(user_id: &UserId, device_id: &DeviceId) -> Self { let store: Box = Box::new(MemoryStore::new()); - let account = ReadOnlyAccount::new(user_id, device_id); - OlmMachine::new_helper( - user_id, - device_id, - store, - account, - PrivateCrossSigningIdentity::empty(user_id), - ) + OlmMachine::with_store(user_id, device_id, store) + .await + .expect("Reading and writing to the memory store always succeeds") } fn new_helper( @@ -233,11 +228,18 @@ impl OlmMachine { } None => { let account = ReadOnlyAccount::new(user_id, device_id); + let device = ReadOnlyDevice::from_account(&account).await; + debug!( ed25519_key = account.identity_keys().ed25519.to_base64().as_str(), "Created a new Olm account" ); - store.save_account(account.clone()).await?; + let changes = Changes { + account: Some(account.clone()), + devices: DeviceChanges { new: vec![device], ..Default::default() }, + ..Default::default() + }; + store.save_changes(changes).await?; account } }; @@ -1178,8 +1180,8 @@ impl OlmMachine { /// # use ruma::{device_id, user_id}; /// # use futures::executor::block_on; /// # let alice = user_id!("@alice:example.org").to_owned(); - /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); /// # block_on(async { + /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; /// let device = machine.get_device(&alice, device_id!("DEVICEID")).await; /// /// println!("{:?}", device); @@ -1219,8 +1221,8 @@ impl OlmMachine { /// # use ruma::{device_id, user_id}; /// # use futures::executor::block_on; /// # let alice = user_id!("@alice:example.org").to_owned(); - /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); /// # block_on(async { + /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; /// let devices = machine.get_user_devices(&alice).await.unwrap(); /// /// for device in devices.devices() { @@ -1255,8 +1257,8 @@ impl OlmMachine { /// # use ruma::{device_id, user_id}; /// # use futures::executor::block_on; /// # let alice = user_id!("@alice:example.org"); - /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); /// # block_on(async { + /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; /// # let export = Cursor::new("".to_owned()); /// let exported_keys = decrypt_key_export(export, "1234").unwrap(); /// machine.import_keys(exported_keys, false, |_, _| {}).await.unwrap(); @@ -1367,8 +1369,8 @@ impl OlmMachine { /// # use ruma::{device_id, user_id, room_id}; /// # use futures::executor::block_on; /// # let alice = user_id!("@alice:example.org"); - /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")); /// # block_on(async { + /// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; /// let room_id = room_id!("!test:localhost"); /// let exported_keys = machine.export_keys(|s| s.room_id() == room_id).await.unwrap(); /// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1); @@ -1598,7 +1600,7 @@ pub(crate) mod tests { } pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; machine.account.inner.update_uploaded_key_count(0); let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload"); let response = keys_upload_response(); @@ -1621,7 +1623,7 @@ pub(crate) mod tests { let alice_id = alice_id(); let alice_device = alice_device_id(); - let alice = OlmMachine::new(alice_id, alice_device); + let alice = OlmMachine::new(alice_id, alice_device).await; let alice_device = ReadOnlyDevice::from_machine(&alice).await; let bob_device = ReadOnlyDevice::from_machine(&bob).await; @@ -1672,13 +1674,13 @@ pub(crate) mod tests { #[async_test] async fn create_olm_machine() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; assert!(!machine.account().shared()); } #[async_test] async fn generate_one_time_keys() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; assert!(machine.account.generate_one_time_keys().await.is_some()); @@ -1694,7 +1696,7 @@ pub(crate) mod tests { #[async_test] async fn test_device_key_signing() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; let mut device_keys = machine.account.device_keys().await; let identity_keys = machine.account.identity_keys(); @@ -1710,7 +1712,7 @@ pub(crate) mod tests { #[async_test] async fn tests_session_invalidation() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; let room_id = room_id!("!test:example.org"); machine.create_outbound_group_session_with_defaults(room_id).await.unwrap(); @@ -1727,7 +1729,7 @@ pub(crate) mod tests { #[async_test] async fn test_invalid_signature() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; let mut device_keys = machine.account.device_keys().await; @@ -1743,7 +1745,7 @@ pub(crate) mod tests { #[async_test] async fn one_time_key_signing() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; machine.account.inner.update_uploaded_key_count(49); let mut one_time_keys = machine.account.signed_one_time_keys().await; @@ -1763,7 +1765,7 @@ pub(crate) mod tests { #[async_test] async fn test_keys_for_upload() { - let machine = OlmMachine::new(user_id(), alice_device_id()); + let machine = OlmMachine::new(user_id(), alice_device_id()).await; machine.account.inner.update_uploaded_key_count(0); let ed25519_key = machine.account.identity_keys().ed25519; diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index a875041d4..47376fa61 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -631,7 +631,7 @@ mod tests { let keys_claim = keys_claim_response(); let txn_id = TransactionId::new(); - let machine = OlmMachine::new(alice_id(), alice_device_id()); + let machine = OlmMachine::new(alice_id(), alice_device_id()).await; machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap(); machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap(); From a425ddf97fe58f9b7e85ec38232a407c4319751e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Apr 2022 14:20:34 +0200 Subject: [PATCH 02/10] fix(crypto): Don't try to encrypt a room key for our own device --- crates/matrix-sdk-crypto/src/machine.rs | 2 +- .../src/session_manager/group_sessions.rs | 47 +++++++++++++++++-- crates/matrix-sdk-crypto/src/store/mod.rs | 16 ++++++- 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index a97728b8e..add0cf4a7 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -92,7 +92,7 @@ pub struct OlmMachine { /// A state machine that handles Olm sessions creation. session_manager: SessionManager, /// A state machine that keeps track of our outbound group sessions. - group_session_manager: GroupSessionManager, + pub(crate) group_session_manager: GroupSessionManager, /// A state machine that is responsible to handle and keep track of SAS /// verification flows. verification_machine: VerificationMachine, diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index 47376fa61..6d3dd6bec 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -359,7 +359,7 @@ impl GroupSessionManager { let mut should_rotate = user_left || visibility_changed; for user_id in users { - let user_devices = self.store.get_user_devices(user_id).await?; + let user_devices = self.store.get_user_devices_filtered(user_id).await?; let non_blacklisted_devices: Vec = user_devices.devices().filter(|d| !d.is_blacklisted()).collect(); @@ -596,7 +596,9 @@ mod tests { client::keys::{claim_keys, get_keys}, IncomingResponse, }, - device_id, room_id, user_id, DeviceId, TransactionId, UserId, + device_id, + events::room::history_visibility::HistoryVisibility, + room_id, user_id, DeviceId, TransactionId, UserId, }; use serde_json::Value; @@ -626,12 +628,12 @@ mod tests { .expect("Can't parse the keys upload response") } - async fn machine() -> OlmMachine { + async fn machine_with_user(user_id: &UserId, device_id: &DeviceId) -> OlmMachine { let keys_query = keys_query_response(); let keys_claim = keys_claim_response(); let txn_id = TransactionId::new(); - let machine = OlmMachine::new(alice_id(), alice_device_id()).await; + let machine = OlmMachine::new(user_id, device_id).await; machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap(); machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap(); @@ -639,6 +641,10 @@ mod tests { machine } + async fn machine() -> OlmMachine { + machine_with_user(alice_id(), alice_device_id()).await + } + #[async_test] async fn test_sharing() { let machine = machine().await; @@ -659,4 +665,37 @@ mod tests { // that all 148 valid sessions get an room key. assert_eq!(event_count, 148); } + + #[async_test] + async fn key_recipient_collecting() { + // The user id comes from the fact that the keys_query.json file uses + // this one. + let user_id = user_id!("@example:localhost"); + let device_id = device_id!("TESTDEVICE"); + let room_id = room_id!("!test:localhost"); + + let machine = machine_with_user(user_id, device_id).await; + + let (outbound, _) = machine + .group_session_manager + .get_or_create_outbound_session(room_id, EncryptionSettings::default()) + .await + .expect("We should be able to create a new session"); + let history_visibility = HistoryVisibility::Joined; + + let users = [user_id].into_iter(); + + let (_, recipients) = machine + .group_session_manager + .collect_session_recipients(users, history_visibility, &outbound) + .await + .expect("We should be able to collect the session recipients"); + + assert!(!recipients.is_empty()); + + // Make sure that our own device isn't part of the recipients. + assert!(!recipients[user_id] + .iter() + .any(|d| d.user_id() == user_id && d.device_id() == device_id)); + } } diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 3adba7589..2c1d8f9d9 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -387,8 +387,22 @@ impl Store { } /// Get all devices associated with the given `user_id` + /// + /// *Note*: This doesn't return our own device. + pub async fn get_user_devices_filtered(&self, user_id: &UserId) -> Result { + self.get_user_devices(user_id).await.map(|mut d| { + if user_id == self.user_id() { + d.inner.remove(self.device_id()); + } + d + }) + } + + /// Get all devices associated with the given `user_id` + /// + /// *Note*: This does also return our own device. pub async fn get_user_devices(&self, user_id: &UserId) -> Result { - let devices = self.inner.get_user_devices(user_id).await?; + let devices = self.get_readonly_devices_unfiltered(user_id).await?; let own_identity = self.inner.get_user_identity(&self.user_id).await?.and_then(|i| i.own().cloned()); From 39ef1d550d6345f41429b671a51779888b6e4128 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Apr 2022 14:21:31 +0200 Subject: [PATCH 03/10] chore(crypto): Improve a log line --- crates/matrix-sdk-crypto/src/identities/device.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 3103b0575..cce606e84 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -517,10 +517,10 @@ impl ReadOnlyDevice { k } else { warn!( - "Trying to encrypt a Megolm session for user {} on device {}, \ - but the device doesn't have a curve25519 key", - self.user_id(), - self.device_id() + user_id = %self.user_id(), + device_id = %self.device_id(), + "Trying to encrypt a Megolm session, but the device doesn't \ + have a curve25519 key", ); return Err(EventError::MissingSenderKey.into()); }; From 189ae93e90732d1376954032753fef8db23a8559 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Tue, 26 Apr 2022 18:17:04 +0200 Subject: [PATCH 04/10] fix(sdk): Export RumaApiError --- crates/matrix-sdk/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index 4fa36d40a..051a29606 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -60,6 +60,6 @@ pub use account::Account; pub use client::{Client, ClientBuildError, ClientBuilder, LoopCtrl}; #[cfg(feature = "image-proc")] pub use error::ImageError; -pub use error::{Error, HttpError, HttpResult, Result}; +pub use error::{Error, HttpError, HttpResult, Result, RumaApiError}; pub use http_client::HttpSend; pub use room_member::RoomMember; From e1baa257132eed06b5e80d652bbb8d85b6a1c4d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Apr 2022 17:47:39 +0200 Subject: [PATCH 05/10] fix(sled): Don't derive a StoreCipher twice if we're reusing the db --- crates/matrix-sdk-sled/src/cryptostore.rs | 25 +++++++++++++---------- crates/matrix-sdk-sled/src/state_store.rs | 7 ++----- crates/matrix-sdk/src/store.rs | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/crates/matrix-sdk-sled/src/cryptostore.rs b/crates/matrix-sdk-sled/src/cryptostore.rs index 709c681ea..0eaeba84b 100644 --- a/crates/matrix-sdk-sled/src/cryptostore.rs +++ b/crates/matrix-sdk-sled/src/cryptostore.rs @@ -221,13 +221,23 @@ impl SledStore { .open() .map_err(|e| CryptoStoreError::Backend(anyhow!(e)))?; - SledStore::open_helper(db, Some(path), passphrase) + let store_cipher = if let Some(passphrase) = passphrase { + Some(Self::get_or_create_store_cipher(passphrase, &db)?) + } else { + None + } + .into(); + + SledStore::open_helper(db, Some(path), store_cipher) } /// Create a sled based cryptostore using the given sled database. /// The given passphrase will be used to encrypt private data. - pub fn open_with_database(db: Db, passphrase: Option<&str>) -> Result { - SledStore::open_helper(db, None, passphrase) + pub fn open_with_database( + db: Db, + store_cipher: Arc>, + ) -> Result { + SledStore::open_helper(db, None, store_cipher) } fn get_account_info(&self) -> Option { @@ -347,7 +357,7 @@ impl SledStore { fn open_helper( db: Db, path: Option, - passphrase: Option<&str>, + store_cipher: Arc>, ) -> Result { let account = db.open_tree("account")?; let private_identity = db.open_tree("private_identity")?; @@ -369,13 +379,6 @@ impl SledStore { let session_cache = SessionStore::new(); - let store_cipher = if let Some(passphrase) = passphrase { - Some(Self::get_or_create_store_cipher(passphrase, &db)?) - } else { - None - } - .into(); - let database = Self { account_info: RwLock::new(None).into(), path, diff --git a/crates/matrix-sdk-sled/src/state_store.rs b/crates/matrix-sdk-sled/src/state_store.rs index 58eeae824..f49927229 100644 --- a/crates/matrix-sdk-sled/src/state_store.rs +++ b/crates/matrix-sdk-sled/src/state_store.rs @@ -295,11 +295,8 @@ impl SledStore { /// /// The given passphrase will be used to encrypt private data. #[cfg(feature = "crypto-store")] - pub fn open_crypto_store( - &self, - passphrase: Option<&str>, - ) -> Result { - CryptoStore::open_with_database(self.inner.clone(), passphrase) + pub fn open_crypto_store(&self) -> Result { + CryptoStore::open_with_database(self.inner.clone(), self.store_cipher.clone()) } fn serialize_event(&self, event: &impl Serialize) -> Result, SledStoreError> { diff --git a/crates/matrix-sdk/src/store.rs b/crates/matrix-sdk/src/store.rs index 3ed61c9e5..91101a8f2 100644 --- a/crates/matrix-sdk/src/store.rs +++ b/crates/matrix-sdk/src/store.rs @@ -80,11 +80,11 @@ fn open_stores_with_path( ) -> Result<(Box, Box), OpenStoreError> { if let Some(passphrase) = passphrase { let state_store = StateStore::open_with_passphrase(path, passphrase)?; - let crypto_store = state_store.open_crypto_store(Some(passphrase))?; + let crypto_store = state_store.open_crypto_store()?; Ok((Box::new(state_store), Box::new(crypto_store))) } else { let state_store = StateStore::open_with_path(path)?; - let crypto_store = state_store.open_crypto_store(None)?; + let crypto_store = state_store.open_crypto_store()?; Ok((Box::new(state_store), Box::new(crypto_store))) } } From 742c1944eb43a4b69e1783c0bd6d594a2e369488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Apr 2022 18:06:39 +0200 Subject: [PATCH 06/10] fix(sled): Put the StoreCipher behind the Arc, not the Option --- crates/matrix-sdk-sled/src/cryptostore.rs | 17 ++++++++--------- crates/matrix-sdk-sled/src/state_store.rs | 17 +++++++++-------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/crates/matrix-sdk-sled/src/cryptostore.rs b/crates/matrix-sdk-sled/src/cryptostore.rs index 0eaeba84b..2232ef972 100644 --- a/crates/matrix-sdk-sled/src/cryptostore.rs +++ b/crates/matrix-sdk-sled/src/cryptostore.rs @@ -171,7 +171,7 @@ struct TrackedUser { #[derive(Clone)] pub struct SledStore { account_info: Arc>>, - store_cipher: Arc>, + store_cipher: Option>, path: Option, inner: Db, @@ -222,11 +222,10 @@ impl SledStore { .map_err(|e| CryptoStoreError::Backend(anyhow!(e)))?; let store_cipher = if let Some(passphrase) = passphrase { - Some(Self::get_or_create_store_cipher(passphrase, &db)?) + Some(Self::get_or_create_store_cipher(passphrase, &db)?.into()) } else { None - } - .into(); + }; SledStore::open_helper(db, Some(path), store_cipher) } @@ -235,7 +234,7 @@ impl SledStore { /// The given passphrase will be used to encrypt private data. pub fn open_with_database( db: Db, - store_cipher: Arc>, + store_cipher: Option>, ) -> Result { SledStore::open_helper(db, None, store_cipher) } @@ -245,7 +244,7 @@ impl SledStore { } fn serialize_value(&self, event: &impl Serialize) -> Result, CryptoStoreError> { - if let Some(key) = &*self.store_cipher { + if let Some(key) = &self.store_cipher { key.encrypt_value(event).map_err(|e| CryptoStoreError::Backend(anyhow!(e))) } else { Ok(serde_json::to_vec(event)?) @@ -256,7 +255,7 @@ impl SledStore { &self, event: &[u8], ) -> Result { - if let Some(key) = &*self.store_cipher { + if let Some(key) = &self.store_cipher { key.decrypt_value(event).map_err(|e| CryptoStoreError::Backend(anyhow!(e))) } else { Ok(serde_json::from_slice(event)?) @@ -264,7 +263,7 @@ impl SledStore { } fn encode_key(&self, table_name: &str, key: T) -> Vec { - if let Some(store_cipher) = &*self.store_cipher { + if let Some(store_cipher) = &self.store_cipher { key.encode_secure(table_name, store_cipher).to_vec() } else { key.encode() @@ -357,7 +356,7 @@ impl SledStore { fn open_helper( db: Db, path: Option, - store_cipher: Arc>, + store_cipher: Option>, ) -> Result { let account = db.open_tree("account")?; let private_identity = db.open_tree("private_identity")?; diff --git a/crates/matrix-sdk-sled/src/state_store.rs b/crates/matrix-sdk-sled/src/state_store.rs index f49927229..34658e56f 100644 --- a/crates/matrix-sdk-sled/src/state_store.rs +++ b/crates/matrix-sdk-sled/src/state_store.rs @@ -143,7 +143,7 @@ type Result = std::result::Result; pub struct SledStore { path: Option, pub(crate) inner: Db, - store_cipher: Arc>, + store_cipher: Option>, session: Tree, account_data: Tree, members: Tree, @@ -181,7 +181,7 @@ impl SledStore { fn open_helper( db: Db, path: Option, - store_cipher: Option, + store_cipher: Option>, ) -> Result { let session = db.open_tree(SESSION)?; let account_data = db.open_tree(ACCOUNT_DATA)?; @@ -215,7 +215,7 @@ impl SledStore { Ok(Self { path, inner: db, - store_cipher: store_cipher.into(), + store_cipher, session, account_data, members, @@ -256,7 +256,7 @@ impl SledStore { SledStore::open_helper( db, None, - Some(StoreCipher::new().expect("can't create store cipher")), + Some(StoreCipher::new().expect("can't create store cipher").into()), ) .map_err(|e| e.into()) } @@ -275,7 +275,8 @@ impl SledStore { let cipher = StoreCipher::new()?; db.insert("store_cipher".encode(), cipher.export(passphrase)?)?; cipher - }; + } + .into(); SledStore::open_helper(db, Some(path), Some(store_cipher)) } @@ -300,7 +301,7 @@ impl SledStore { } fn serialize_event(&self, event: &impl Serialize) -> Result, SledStoreError> { - if let Some(key) = &*self.store_cipher { + if let Some(key) = &self.store_cipher { Ok(key.encrypt_value(event)?) } else { Ok(serde_json::to_vec(event)?) @@ -311,7 +312,7 @@ impl SledStore { &self, event: &[u8], ) -> Result { - if let Some(key) = &*self.store_cipher { + if let Some(key) = &self.store_cipher { Ok(key.decrypt_value(event)?) } else { Ok(serde_json::from_slice(event)?) @@ -319,7 +320,7 @@ impl SledStore { } fn encode_key(&self, table_name: &str, key: T) -> Vec { - if let Some(store_cipher) = &*self.store_cipher { + if let Some(store_cipher) = &self.store_cipher { key.encode_secure(table_name, store_cipher).to_vec() } else { key.encode() From 6e89ecfbf0bac1bc99bee26604faf4d1c3d30b04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Tue, 26 Apr 2022 18:34:44 +0200 Subject: [PATCH 07/10] fix(indexeddb): Don't derive a StoreCipher twice if we're reusing the db --- .../matrix-sdk-indexeddb/src/cryptostore.rs | 19 +++++++++++-------- crates/matrix-sdk-indexeddb/src/lib.rs | 3 ++- .../matrix-sdk-indexeddb/src/state_store.rs | 8 ++++---- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/crates/matrix-sdk-indexeddb/src/cryptostore.rs b/crates/matrix-sdk-indexeddb/src/cryptostore.rs index fca9f0df9..b1d9b4979 100644 --- a/crates/matrix-sdk-indexeddb/src/cryptostore.rs +++ b/crates/matrix-sdk-indexeddb/src/cryptostore.rs @@ -73,7 +73,7 @@ pub struct IndexeddbStore { name: String, pub(crate) inner: IdbDatabase, - store_cipher: Arc>, + store_cipher: Option>, session_cache: SessionStore, tracked_users_cache: Arc>, @@ -133,7 +133,10 @@ pub struct AccountInfo { } impl IndexeddbStore { - async fn open_helper(prefix: String, store_cipher: Option) -> Result { + pub(crate) async fn open_with_store_cipher( + prefix: String, + store_cipher: Option>, + ) -> Result { let name = format!("{:0}::matrix-sdk-crypto", prefix); // Open my_db v1 @@ -167,7 +170,7 @@ impl IndexeddbStore { name, session_cache, inner: db, - store_cipher: store_cipher.into(), + store_cipher, account_info: RwLock::new(None).into(), tracked_users_cache: DashSet::new().into(), users_for_key_query_cache: DashSet::new().into(), @@ -176,7 +179,7 @@ impl IndexeddbStore { /// Open a new IndexeddbStore with default name and no passphrase pub async fn open() -> Result { - IndexeddbStore::open_helper("crypto".to_owned(), None).await + IndexeddbStore::open_with_store_cipher("crypto".to_owned(), None).await } /// Open a new IndexeddbStore with given name and passphrase @@ -229,16 +232,16 @@ impl IndexeddbStore { } }; - IndexeddbStore::open_helper(prefix, Some(store_cipher)).await + IndexeddbStore::open_with_store_cipher(prefix, Some(store_cipher.into())).await } /// Open a new IndexeddbStore with given name and no passphrase pub async fn open_with_name(name: String) -> Result { - IndexeddbStore::open_helper(name, None).await + IndexeddbStore::open_with_store_cipher(name, None).await } fn serialize_value(&self, value: &impl Serialize) -> Result { - if let Some(key) = &*self.store_cipher { + if let Some(key) = &self.store_cipher { let value = key.encrypt_value(value).map_err(|e| CryptoStoreError::Backend(anyhow!(e)))?; @@ -252,7 +255,7 @@ impl IndexeddbStore { &self, value: JsValue, ) -> Result { - if let Some(key) = &*self.store_cipher { + if let Some(key) = &self.store_cipher { let value: Vec = value.into_serde()?; key.decrypt_value(&value).map_err(|e| CryptoStoreError::Backend(anyhow!(e))) } else { diff --git a/crates/matrix-sdk-indexeddb/src/lib.rs b/crates/matrix-sdk-indexeddb/src/lib.rs index af337a3b6..71c302c53 100644 --- a/crates/matrix-sdk-indexeddb/src/lib.rs +++ b/crates/matrix-sdk-indexeddb/src/lib.rs @@ -33,7 +33,8 @@ async fn open_stores_with_name( if let Some(passphrase) = passphrase { let state_store = StateStore::open_with_passphrase(name.clone(), passphrase).await?; - let crypto_store = CryptoStore::open_with_passphrase(name, passphrase).await?; + let crypto_store = + CryptoStore::open_with_store_cipher(name, state_store.store_cipher.clone()).await?; Ok((Box::new(state_store), Box::new(crypto_store))) } else { let state_store = StateStore::open_with_name(name.clone()).await?; diff --git a/crates/matrix-sdk-indexeddb/src/state_store.rs b/crates/matrix-sdk-indexeddb/src/state_store.rs index a79806fbb..404446a3f 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeSet; +use std::{collections::BTreeSet, sync::Arc}; use anyhow::anyhow; use futures_util::stream; @@ -140,7 +140,7 @@ mod KEYS { pub struct IndexeddbStore { name: String, pub(crate) inner: IdbDatabase, - store_cipher: Option, + pub(crate) store_cipher: Option>, } impl std::fmt::Debug for IndexeddbStore { @@ -152,7 +152,7 @@ impl std::fmt::Debug for IndexeddbStore { type Result = std::result::Result; impl IndexeddbStore { - async fn open_helper(name: String, store_cipher: Option) -> Result { + async fn open_helper(name: String, store_cipher: Option>) -> Result { // Open my_db v1 let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?; db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { @@ -245,7 +245,7 @@ impl IndexeddbStore { tx.await.into_result()?; - IndexeddbStore::open_helper(name, Some(cipher)).await + IndexeddbStore::open_helper(name, Some(cipher.into())).await } pub async fn open_with_name(name: String) -> StoreResult { From dfdaea786b9f04c2b189db5d2fffb08b0ea4f739 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Tue, 26 Apr 2022 18:49:21 +0200 Subject: [PATCH 08/10] docs: Document RumaApiError variants --- crates/matrix-sdk/src/error.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/matrix-sdk/src/error.rs b/crates/matrix-sdk/src/error.rs index 84d047478..91a708945 100644 --- a/crates/matrix-sdk/src/error.rs +++ b/crates/matrix-sdk/src/error.rs @@ -47,8 +47,11 @@ pub type HttpResult = std::result::Result; /// representation if the endpoint is from that. #[derive(Error, Debug)] pub enum RumaApiError { + /// A client API response error. #[error(transparent)] ClientApi(ruma::api::client::Error), + + /// Another API reponse error. #[error(transparent)] Other(ruma::api::error::MatrixError), } From d62cb9650f53ecc4d19640121c952f39f6316b51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Tue, 26 Apr 2022 18:59:51 +0200 Subject: [PATCH 09/10] Fix typo --- crates/matrix-sdk/src/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/matrix-sdk/src/error.rs b/crates/matrix-sdk/src/error.rs index 91a708945..8c611f131 100644 --- a/crates/matrix-sdk/src/error.rs +++ b/crates/matrix-sdk/src/error.rs @@ -51,7 +51,7 @@ pub enum RumaApiError { #[error(transparent)] ClientApi(ruma::api::client::Error), - /// Another API reponse error. + /// Another API response error. #[error(transparent)] Other(ruma::api::error::MatrixError), } From 910bf531feeda7d1b86dba35889fe4ef8b69c1bc Mon Sep 17 00:00:00 2001 From: Julian Sparber Date: Wed, 27 Apr 2022 12:53:51 +0200 Subject: [PATCH 10/10] feat(sdk): Add method to set whether a room is DM and store all targets This ensures also that we use, for user verification, only a DM room with no members other then ourself and the user to be verified. --- crates/matrix-sdk-base/src/client.rs | 7 ++- crates/matrix-sdk-base/src/rooms/mod.rs | 10 ++-- crates/matrix-sdk-base/src/rooms/normal.rs | 15 +++-- .../src/encryption/identities/users.rs | 9 +++ crates/matrix-sdk/src/encryption/mod.rs | 9 +-- crates/matrix-sdk/src/room/common.rs | 60 +++++++++++++++++-- 6 files changed, 88 insertions(+), 22 deletions(-) diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 503766f97..59f368496 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -497,11 +497,12 @@ impl BaseClient { ); if let Some(room) = changes.room_infos.get_mut(room_id) { - room.base_info.dm_target = Some(user_id.clone()); + room.base_info.dm_targets.insert(user_id.clone()); } else if let Some(room) = self.store.get_room(room_id) { let mut info = room.clone_info(); - info.base_info.dm_target = Some(user_id.clone()); - changes.add_room(info); + if info.base_info.dm_targets.insert(user_id.clone()) { + changes.add_room(info); + } } } } diff --git a/crates/matrix-sdk-base/src/rooms/mod.rs b/crates/matrix-sdk-base/src/rooms/mod.rs index ca241bf76..bd2f0e85c 100644 --- a/crates/matrix-sdk-base/src/rooms/mod.rs +++ b/crates/matrix-sdk-base/src/rooms/mod.rs @@ -1,7 +1,7 @@ mod members; mod normal; -use std::cmp::max; +use std::{cmp::max, collections::HashSet}; pub use members::RoomMember; pub use normal::{Room, RoomInfo, RoomType}; @@ -29,9 +29,9 @@ pub struct BaseRoomInfo { pub(crate) canonical_alias: Option, /// The `m.room.create` event content of this room. pub(crate) create: Option, - /// The user id this room is sharing the direct message with, if the room is - /// a direct message. - pub(crate) dm_target: Option, + /// A list of user ids this room is considered as direct message, if this + /// room is a DM. + pub(crate) dm_targets: HashSet, /// The `m.room.encryption` event content that enabled E2EE in this room. pub(crate) encryption: Option, /// The guest access policy of this room. @@ -183,7 +183,7 @@ impl Default for BaseRoomInfo { avatar_url: None, canonical_alias: None, create: None, - dm_target: None, + dm_targets: Default::default(), encryption: None, guest_access: GuestAccess::Forbidden, history_visibility: HistoryVisibility::WorldReadable, diff --git a/crates/matrix-sdk-base/src/rooms/normal.rs b/crates/matrix-sdk-base/src/rooms/normal.rs index 969afed7b..53a5cee5f 100644 --- a/crates/matrix-sdk-base/src/rooms/normal.rs +++ b/crates/matrix-sdk-base/src/rooms/normal.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::{Arc, RwLock as SyncRwLock}; +use std::{ + collections::HashSet, + sync::{Arc, RwLock as SyncRwLock}, +}; use dashmap::DashSet; use futures_channel::mpsc; @@ -184,17 +187,17 @@ impl Room { /// Is this room considered a direct message. pub fn is_direct(&self) -> bool { - self.inner.read().unwrap().base_info.dm_target.is_some() + !self.inner.read().unwrap().base_info.dm_targets.is_empty() } - /// If this room is a direct message, get the member that we're sharing the + /// If this room is a direct message, get the members that we're sharing the /// room with. /// /// *Note*: The member list might have been modified in the meantime and - /// the target might not even be in the room anymore. This setting should + /// the targets might not even be in the room anymore. This setting should /// only be considered as guidance. - pub fn direct_target(&self) -> Option { - self.inner.read().unwrap().base_info.dm_target.clone() + pub fn direct_targets(&self) -> HashSet { + self.inner.read().unwrap().base_info.dm_targets.clone() } /// Is the room encrypted. diff --git a/crates/matrix-sdk/src/encryption/identities/users.rs b/crates/matrix-sdk/src/encryption/identities/users.rs index cf06ad262..2886534a7 100644 --- a/crates/matrix-sdk/src/encryption/identities/users.rs +++ b/crates/matrix-sdk/src/encryption/identities/users.rs @@ -465,6 +465,15 @@ impl OtherUserIdentity { let content = self.inner.verification_request_content(methods.clone()).await; let room = if let Some(room) = self.direct_message_room.read().await.as_ref() { + // Make sure that the user, to be verified, is still in the room + if !room + .active_members() + .await? + .iter() + .any(|member| member.user_id() == self.inner.user_id()) + { + room.invite_user_by_id(self.inner.user_id()).await?; + } room.clone() } else if let Some(room) = self.client.create_dm_room(self.inner.user_id().to_owned()).await? diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index f32195580..fa75cc35b 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -404,11 +404,12 @@ impl Client { #[cfg(feature = "encryption")] fn get_dm_room(&self, user_id: &UserId) -> Option { let rooms = self.joined_rooms(); - let room_pairs: Vec<_> = - rooms.iter().map(|r| (r.room_id().to_owned(), r.direct_target())).collect(); - trace!(rooms = ?room_pairs, "Finding direct room"); - let room = rooms.into_iter().find(|r| r.direct_target().as_deref() == Some(user_id)); + // Find the room we share with the `user_id` and only with `user_id` + let room = rooms.into_iter().find(|r| { + let targets = r.direct_targets(); + targets.len() == 1 && targets.contains(user_id) + }); trace!(?room, "Found room"); room diff --git a/crates/matrix-sdk/src/room/common.rs b/crates/matrix-sdk/src/room/common.rs index 99f631f35..2820907ae 100644 --- a/crates/matrix-sdk/src/room/common.rs +++ b/crates/matrix-sdk/src/room/common.rs @@ -1,4 +1,4 @@ -use std::{ops::Deref, sync::Arc}; +use std::{collections::BTreeMap, ops::Deref, sync::Arc}; use futures_core::stream::Stream; use matrix_sdk_base::{ @@ -8,6 +8,7 @@ use matrix_sdk_base::{ use matrix_sdk_common::locks::Mutex; use ruma::{ api::client::{ + config::set_global_account_data, filter::{LazyLoadOptions, RoomEventFilter}, membership::{get_member_events, join_room_by_id, leave_room}, message::get_message_events::{self, v3::Direction}, @@ -16,10 +17,11 @@ use ruma::{ }, assign, events::{ + direct::DirectEvent, room::{history_visibility::HistoryVisibility, MediaSource}, tag::{TagInfo, TagName}, - AnyRoomAccountDataEvent, AnyStateEvent, AnySyncStateEvent, RedactContent, - RedactedEventContent, RoomAccountDataEvent, RoomAccountDataEventContent, + AnyRoomAccountDataEvent, AnyStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, + RedactContent, RedactedEventContent, RoomAccountDataEvent, RoomAccountDataEventContent, RoomAccountDataEventType, StateEventContent, StateEventType, StaticEventContent, SyncStateEvent, }, @@ -30,7 +32,7 @@ use ruma::{ use crate::{ media::{MediaFormat, MediaRequest}, room::RoomType, - BaseRoom, Client, HttpError, HttpResult, Result, RoomMember, + BaseRoom, Client, Error, HttpError, HttpResult, Result, RoomMember, }; /// A struct containing methods that are common for Joined, Invited and Left @@ -817,6 +819,56 @@ impl Common { let request = delete_tag::v3::Request::new(&user_id, self.inner.room_id(), tag.as_ref()); self.client.send(request, None).await } + + /// Sets whether this room is a DM. + /// + /// When setting this room as DM, it will be marked as DM for all active + /// members of the room. When unsetting this room as DM, it will be + /// unmarked as DM for all users, not just the members. + /// + /// # Arguments + /// * `is_direct` - Whether to mark this room as direct. + pub async fn set_is_direct(&self, is_direct: bool) -> Result<()> { + let user_id = self + .client + .user_id() + .await + .ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?; + + let mut content = self + .client + .store() + .get_account_data_event(GlobalAccountDataEventType::Direct) + .await? + .map(|e| e.deserialize_as::()) + .transpose()? + .map(|e| e.content) + .unwrap_or_else(|| ruma::events::direct::DirectEventContent(BTreeMap::new())); + + let this_room_id = self.inner.room_id(); + + if is_direct { + let room_members = self.active_members().await?; + for member in room_members { + let entry = content.entry(member.user_id().to_owned()).or_default(); + if !entry.iter().any(|room_id| room_id == this_room_id) { + entry.push(this_room_id.to_owned()); + } + } + } else { + for (_, list) in content.iter_mut() { + list.retain(|room_id| *room_id != this_room_id); + } + + // Remove user ids that don't have any room marked as DM + content.retain(|_, list| !list.is_empty()); + } + + let request = set_global_account_data::v3::Request::new(&content, &user_id)?; + + self.client.send(request, None).await?; + Ok(()) + } } /// Options for [`messages`][Common::messages].