Merge branch 'main' into gnunicorn/issue133

This commit is contained in:
Benjamin Kampmann
2022-04-27 15:45:39 +02:00
committed by GitHub
22 changed files with 253 additions and 117 deletions

View File

@@ -51,7 +51,7 @@ fn huge_keys_query_response() -> get_keys::v3::Response {
pub fn keys_query(c: &mut Criterion) {
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let machine = OlmMachine::new(alice_id(), alice_device_id());
let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id()));
let response = keys_query_response();
let txn_id = TransactionId::new();
@@ -101,7 +101,7 @@ pub fn keys_claiming(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::new("memory store", &name), &response, |b, response| {
b.iter_batched(
|| {
let machine = OlmMachine::new(alice_id(), alice_device_id());
let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id()));
runtime
.block_on(machine.mark_request_as_sent(&txn_id, &keys_query_response))
.unwrap();
@@ -151,7 +151,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
let count = response.one_time_keys.values().fold(0, |acc, d| acc + d.len());
let machine = OlmMachine::new(alice_id(), alice_device_id());
let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id()));
runtime.block_on(machine.mark_request_as_sent(&txn_id, &keys_query_response)).unwrap();
runtime.block_on(machine.mark_request_as_sent(&txn_id, &response)).unwrap();
@@ -214,7 +214,7 @@ pub fn room_key_sharing(c: &mut Criterion) {
pub fn devices_missing_sessions_collecting(c: &mut Criterion) {
let runtime = Builder::new_multi_thread().build().expect("Can't create runtime");
let machine = OlmMachine::new(alice_id(), alice_device_id());
let machine = runtime.block_on(OlmMachine::new(alice_id(), alice_device_id()));
let response = huge_keys_query_response();
let txn_id = TransactionId::new();
let users: Vec<OwnedUserId> = response.device_keys.keys().cloned().collect();

View File

@@ -497,11 +497,12 @@ impl BaseClient {
);
if let Some(room) = changes.room_infos.get_mut(room_id) {
room.base_info.dm_target = Some(user_id.clone());
room.base_info.dm_targets.insert(user_id.clone());
} else if let Some(room) = self.store.get_room(room_id) {
let mut info = room.clone_info();
info.base_info.dm_target = Some(user_id.clone());
changes.add_room(info);
if info.base_info.dm_targets.insert(user_id.clone()) {
changes.add_room(info);
}
}
}
}

View File

@@ -1,7 +1,7 @@
mod members;
mod normal;
use std::{cmp::max, fmt};
use std::{cmp::max, fmt, collections::HashSet};
pub use members::RoomMember;
pub use normal::{Room, RoomInfo, RoomType};
@@ -59,9 +59,9 @@ pub struct BaseRoomInfo {
pub(crate) canonical_alias: Option<OwnedRoomAliasId>,
/// The `m.room.create` event content of this room.
pub(crate) create: Option<RoomCreateEventContent>,
/// The user id this room is sharing the direct message with, if the room is
/// a direct message.
pub(crate) dm_target: Option<OwnedUserId>,
/// A list of user ids this room is considered as direct message, if this
/// room is a DM.
pub(crate) dm_targets: HashSet<OwnedUserId>,
/// The `m.room.encryption` event content that enabled E2EE in this room.
pub(crate) encryption: Option<RoomEncryptionEventContent>,
/// The guest access policy of this room.
@@ -213,7 +213,7 @@ impl Default for BaseRoomInfo {
avatar_url: None,
canonical_alias: None,
create: None,
dm_target: None,
dm_targets: Default::default(),
encryption: None,
guest_access: GuestAccess::Forbidden,
history_visibility: HistoryVisibility::WorldReadable,

View File

@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::{Arc, RwLock as SyncRwLock};
use std::{
collections::HashSet,
sync::{Arc, RwLock as SyncRwLock},
};
use dashmap::DashSet;
use futures_channel::mpsc;
@@ -184,17 +187,17 @@ impl Room {
/// Is this room considered a direct message.
pub fn is_direct(&self) -> bool {
self.inner.read().unwrap().base_info.dm_target.is_some()
!self.inner.read().unwrap().base_info.dm_targets.is_empty()
}
/// If this room is a direct message, get the member that we're sharing the
/// If this room is a direct message, get the members that we're sharing the
/// room with.
///
/// *Note*: The member list might have been modified in the meantime and
/// the target might not even be in the room anymore. This setting should
/// the targets might not even be in the room anymore. This setting should
/// only be considered as guidance.
pub fn direct_target(&self) -> Option<OwnedUserId> {
self.inner.read().unwrap().base_info.dm_target.clone()
pub fn direct_targets(&self) -> HashSet<OwnedUserId> {
self.inner.read().unwrap().base_info.dm_targets.clone()
}
/// Is the room encrypted.

View File

@@ -28,7 +28,7 @@ use ruma::{
#[tokio::main]
async fn main() -> Result<(), OlmError> {
let alice = user_id!("@alice:example.org");
let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
let to_device_events = ToDevice::default();
let changed_devices = DeviceLists::default();
@@ -69,4 +69,4 @@ The following crate feature flags are available:
* `qrcode`: Enbles QRcode generation and reading code
* `testing`: provides facilities and functions for tests, in particular for integration testing store implementations. ATTENTION: do not ever use outside of tests, we do not provide any stability warantees on these, these are merely helpers. If you find you _need_ any function provided here outside of tests, please open a Github Issue and inform us about your use case for us to consider.
* `testing`: provides facilities and functions for tests, in particular for integration testing store implementations. ATTENTION: do not ever use outside of tests, we do not provide any stability warantees on these, these are merely helpers. If you find you _need_ any function provided here outside of tests, please open a Github Issue and inform us about your use case for us to consider.

View File

@@ -459,7 +459,7 @@ mod tests {
#[async_test]
async fn memory_store_backups() -> Result<(), OlmError> {
let machine = OlmMachine::new(alice_id(), alice_device_id());
let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
backup_flow(machine).await
}

View File

@@ -80,8 +80,8 @@ pub enum KeyExportError {
/// # use ruma::{device_id, user_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
/// # block_on(async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// # let export = Cursor::new("".to_owned());
/// let exported_keys = decrypt_key_export(export, "1234").unwrap();
/// machine.import_keys(exported_keys, false, |_, _| {}).await.unwrap();
@@ -137,8 +137,8 @@ pub fn decrypt_key_export(
/// # use ruma::{device_id, user_id, room_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
/// # block_on(async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// let room_id = room_id!("!test:localhost");
/// let exported_keys = machine.export_keys(|s| s.room_id() == room_id).await.unwrap();
/// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1);

View File

@@ -40,6 +40,8 @@ use tracing::warn;
use vodozemac::{Curve25519PublicKey, Ed25519PublicKey};
use super::{atomic_bool_deserializer, atomic_bool_serializer};
#[cfg(any(test, feature = "testing"))]
use crate::OlmMachine;
use crate::{
error::{EventError, OlmError, OlmResult, SignatureError},
identities::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities},
@@ -47,10 +49,8 @@ use crate::{
store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult},
types::{DeviceKey, DeviceKeys, SignedKey},
verification::VerificationMachine,
OutgoingVerificationRequest, Sas, ToDeviceRequest, VerificationRequest,
OutgoingVerificationRequest, ReadOnlyAccount, Sas, ToDeviceRequest, VerificationRequest,
};
#[cfg(any(test, feature = "testing"))]
use crate::{OlmMachine, ReadOnlyAccount};
/// A read-only version of a `Device`.
#[derive(Clone, Serialize, Deserialize)]
@@ -517,10 +517,10 @@ impl ReadOnlyDevice {
k
} else {
warn!(
"Trying to encrypt a Megolm session for user {} on device {}, \
but the device doesn't have a curve25519 key",
self.user_id(),
self.device_id()
user_id = %self.user_id(),
device_id = %self.device_id(),
"Trying to encrypt a Megolm session, but the device doesn't \
have a curve25519 key",
);
return Err(EventError::MissingSenderKey.into());
};
@@ -601,14 +601,22 @@ impl ReadOnlyDevice {
ReadOnlyDevice::from_account(machine.account()).await
}
#[cfg(any(test, feature = "testing"))]
#[allow(dead_code)]
/// Generate the Device from the reference of an Account.
/// Create a `ReadOnlyDevice` from an `Account`
///
/// TESTING FACILITY ONLY, DO NOT USE OUTSIDE OF TESTS
/// We will have our own device in the store once we receive a keys/query
/// response, but this is useful to create it before we receive such a
/// response.
///
/// It also makes it easier to check that the server doesn't lie about our
/// own device.
///
/// *Don't* use this after we received a keys/query response, other
/// users/devices might add signatures to our own device, which can't be
/// replicated locally.
pub async fn from_account(account: &ReadOnlyAccount) -> ReadOnlyDevice {
let device_keys = account.device_keys().await;
ReadOnlyDevice::try_from(&device_keys).unwrap()
ReadOnlyDevice::try_from(&device_keys)
.expect("Creating a device from our own account should always succeed")
}
}

View File

@@ -67,7 +67,7 @@ use crate::{
SecretImportError, Store,
},
verification::{Verification, VerificationMachine, VerificationRequest},
CrossSigningKeyExport, RoomKeyImportResult, ToDeviceRequest,
CrossSigningKeyExport, ReadOnlyDevice, RoomKeyImportResult, ToDeviceRequest,
};
/// State machine implementation of the Olm/Megolm encryption protocol used for
@@ -92,7 +92,7 @@ pub struct OlmMachine {
/// A state machine that handles Olm sessions creation.
session_manager: SessionManager,
/// A state machine that keeps track of our outbound group sessions.
group_session_manager: GroupSessionManager,
pub(crate) group_session_manager: GroupSessionManager,
/// A state machine that is responsible to handle and keep track of SAS
/// verification flows.
verification_machine: VerificationMachine,
@@ -128,17 +128,12 @@ impl OlmMachine {
/// * `user_id` - The unique id of the user that owns this machine.
///
/// * `device_id` - The unique id of the device that owns this machine.
pub fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
pub async fn new(user_id: &UserId, device_id: &DeviceId) -> Self {
let store: Box<dyn CryptoStore> = Box::new(MemoryStore::new());
let account = ReadOnlyAccount::new(user_id, device_id);
OlmMachine::new_helper(
user_id,
device_id,
store,
account,
PrivateCrossSigningIdentity::empty(user_id),
)
OlmMachine::with_store(user_id, device_id, store)
.await
.expect("Reading and writing to the memory store always succeeds")
}
fn new_helper(
@@ -233,11 +228,18 @@ impl OlmMachine {
}
None => {
let account = ReadOnlyAccount::new(user_id, device_id);
let device = ReadOnlyDevice::from_account(&account).await;
debug!(
ed25519_key = account.identity_keys().ed25519.to_base64().as_str(),
"Created a new Olm account"
);
store.save_account(account.clone()).await?;
let changes = Changes {
account: Some(account.clone()),
devices: DeviceChanges { new: vec![device], ..Default::default() },
..Default::default()
};
store.save_changes(changes).await?;
account
}
};
@@ -1178,8 +1180,8 @@ impl OlmMachine {
/// # use ruma::{device_id, user_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org").to_owned();
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
/// # block_on(async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// let device = machine.get_device(&alice, device_id!("DEVICEID")).await;
///
/// println!("{:?}", device);
@@ -1219,8 +1221,8 @@ impl OlmMachine {
/// # use ruma::{device_id, user_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org").to_owned();
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
/// # block_on(async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// let devices = machine.get_user_devices(&alice).await.unwrap();
///
/// for device in devices.devices() {
@@ -1255,8 +1257,8 @@ impl OlmMachine {
/// # use ruma::{device_id, user_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
/// # block_on(async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// # let export = Cursor::new("".to_owned());
/// let exported_keys = decrypt_key_export(export, "1234").unwrap();
/// machine.import_keys(exported_keys, false, |_, _| {}).await.unwrap();
@@ -1367,8 +1369,8 @@ impl OlmMachine {
/// # use ruma::{device_id, user_id, room_id};
/// # use futures::executor::block_on;
/// # let alice = user_id!("@alice:example.org");
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID"));
/// # block_on(async {
/// # let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await;
/// let room_id = room_id!("!test:localhost");
/// let exported_keys = machine.export_keys(|s| s.room_id() == room_id).await.unwrap();
/// let encrypted_export = encrypt_key_export(&exported_keys, "1234", 1);
@@ -1598,7 +1600,7 @@ pub(crate) mod tests {
}
pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
machine.account.inner.update_uploaded_key_count(0);
let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload");
let response = keys_upload_response();
@@ -1621,7 +1623,7 @@ pub(crate) mod tests {
let alice_id = alice_id();
let alice_device = alice_device_id();
let alice = OlmMachine::new(alice_id, alice_device);
let alice = OlmMachine::new(alice_id, alice_device).await;
let alice_device = ReadOnlyDevice::from_machine(&alice).await;
let bob_device = ReadOnlyDevice::from_machine(&bob).await;
@@ -1672,13 +1674,13 @@ pub(crate) mod tests {
#[async_test]
async fn create_olm_machine() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
assert!(!machine.account().shared());
}
#[async_test]
async fn generate_one_time_keys() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
assert!(machine.account.generate_one_time_keys().await.is_some());
@@ -1694,7 +1696,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_device_key_signing() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let mut device_keys = machine.account.device_keys().await;
let identity_keys = machine.account.identity_keys();
@@ -1710,7 +1712,7 @@ pub(crate) mod tests {
#[async_test]
async fn tests_session_invalidation() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let room_id = room_id!("!test:example.org");
machine.create_outbound_group_session_with_defaults(room_id).await.unwrap();
@@ -1727,7 +1729,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_invalid_signature() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
let mut device_keys = machine.account.device_keys().await;
@@ -1743,7 +1745,7 @@ pub(crate) mod tests {
#[async_test]
async fn one_time_key_signing() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
machine.account.inner.update_uploaded_key_count(49);
let mut one_time_keys = machine.account.signed_one_time_keys().await;
@@ -1763,7 +1765,7 @@ pub(crate) mod tests {
#[async_test]
async fn test_keys_for_upload() {
let machine = OlmMachine::new(user_id(), alice_device_id());
let machine = OlmMachine::new(user_id(), alice_device_id()).await;
machine.account.inner.update_uploaded_key_count(0);
let ed25519_key = machine.account.identity_keys().ed25519;

View File

@@ -359,7 +359,7 @@ impl GroupSessionManager {
let mut should_rotate = user_left || visibility_changed;
for user_id in users {
let user_devices = self.store.get_user_devices(user_id).await?;
let user_devices = self.store.get_user_devices_filtered(user_id).await?;
let non_blacklisted_devices: Vec<Device> =
user_devices.devices().filter(|d| !d.is_blacklisted()).collect();
@@ -596,7 +596,9 @@ mod tests {
client::keys::{claim_keys, get_keys},
IncomingResponse,
},
device_id, room_id, user_id, DeviceId, TransactionId, UserId,
device_id,
events::room::history_visibility::HistoryVisibility,
room_id, user_id, DeviceId, TransactionId, UserId,
};
use serde_json::Value;
@@ -626,12 +628,12 @@ mod tests {
.expect("Can't parse the keys upload response")
}
async fn machine() -> OlmMachine {
async fn machine_with_user(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
let keys_query = keys_query_response();
let keys_claim = keys_claim_response();
let txn_id = TransactionId::new();
let machine = OlmMachine::new(alice_id(), alice_device_id());
let machine = OlmMachine::new(user_id, device_id).await;
machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
@@ -639,6 +641,10 @@ mod tests {
machine
}
async fn machine() -> OlmMachine {
machine_with_user(alice_id(), alice_device_id()).await
}
#[async_test]
async fn test_sharing() {
let machine = machine().await;
@@ -659,4 +665,37 @@ mod tests {
// that all 148 valid sessions get an room key.
assert_eq!(event_count, 148);
}
#[async_test]
async fn key_recipient_collecting() {
// The user id comes from the fact that the keys_query.json file uses
// this one.
let user_id = user_id!("@example:localhost");
let device_id = device_id!("TESTDEVICE");
let room_id = room_id!("!test:localhost");
let machine = machine_with_user(user_id, device_id).await;
let (outbound, _) = machine
.group_session_manager
.get_or_create_outbound_session(room_id, EncryptionSettings::default())
.await
.expect("We should be able to create a new session");
let history_visibility = HistoryVisibility::Joined;
let users = [user_id].into_iter();
let (_, recipients) = machine
.group_session_manager
.collect_session_recipients(users, history_visibility, &outbound)
.await
.expect("We should be able to collect the session recipients");
assert!(!recipients.is_empty());
// Make sure that our own device isn't part of the recipients.
assert!(!recipients[user_id]
.iter()
.any(|d| d.user_id() == user_id && d.device_id() == device_id));
}
}

View File

@@ -387,8 +387,22 @@ impl Store {
}
/// Get all devices associated with the given `user_id`
///
/// *Note*: This doesn't return our own device.
pub async fn get_user_devices_filtered(&self, user_id: &UserId) -> Result<UserDevices> {
self.get_user_devices(user_id).await.map(|mut d| {
if user_id == self.user_id() {
d.inner.remove(self.device_id());
}
d
})
}
/// Get all devices associated with the given `user_id`
///
/// *Note*: This does also return our own device.
pub async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
let devices = self.inner.get_user_devices(user_id).await?;
let devices = self.get_readonly_devices_unfiltered(user_id).await?;
let own_identity =
self.inner.get_user_identity(&self.user_id).await?.and_then(|i| i.own().cloned());

View File

@@ -73,7 +73,7 @@ pub struct IndexeddbStore {
name: String,
pub(crate) inner: IdbDatabase,
store_cipher: Arc<Option<StoreCipher>>,
store_cipher: Option<Arc<StoreCipher>>,
session_cache: SessionStore,
tracked_users_cache: Arc<DashSet<OwnedUserId>>,
@@ -133,7 +133,10 @@ pub struct AccountInfo {
}
impl IndexeddbStore {
async fn open_helper(prefix: String, store_cipher: Option<StoreCipher>) -> Result<Self> {
pub(crate) async fn open_with_store_cipher(
prefix: String,
store_cipher: Option<Arc<StoreCipher>>,
) -> Result<Self> {
let name = format!("{:0}::matrix-sdk-crypto", prefix);
// Open my_db v1
@@ -167,7 +170,7 @@ impl IndexeddbStore {
name,
session_cache,
inner: db,
store_cipher: store_cipher.into(),
store_cipher,
account_info: RwLock::new(None).into(),
tracked_users_cache: DashSet::new().into(),
users_for_key_query_cache: DashSet::new().into(),
@@ -176,7 +179,7 @@ impl IndexeddbStore {
/// Open a new IndexeddbStore with default name and no passphrase
pub async fn open() -> Result<Self> {
IndexeddbStore::open_helper("crypto".to_owned(), None).await
IndexeddbStore::open_with_store_cipher("crypto".to_owned(), None).await
}
/// Open a new IndexeddbStore with given name and passphrase
@@ -229,16 +232,16 @@ impl IndexeddbStore {
}
};
IndexeddbStore::open_helper(prefix, Some(store_cipher)).await
IndexeddbStore::open_with_store_cipher(prefix, Some(store_cipher.into())).await
}
/// Open a new IndexeddbStore with given name and no passphrase
pub async fn open_with_name(name: String) -> Result<Self> {
IndexeddbStore::open_helper(name, None).await
IndexeddbStore::open_with_store_cipher(name, None).await
}
fn serialize_value(&self, value: &impl Serialize) -> Result<JsValue, CryptoStoreError> {
if let Some(key) = &*self.store_cipher {
if let Some(key) = &self.store_cipher {
let value =
key.encrypt_value(value).map_err(|e| CryptoStoreError::Backend(anyhow!(e)))?;
@@ -252,7 +255,7 @@ impl IndexeddbStore {
&self,
value: JsValue,
) -> Result<T, CryptoStoreError> {
if let Some(key) = &*self.store_cipher {
if let Some(key) = &self.store_cipher {
let value: Vec<u8> = value.into_serde()?;
key.decrypt_value(&value).map_err(|e| CryptoStoreError::Backend(anyhow!(e)))
} else {

View File

@@ -33,7 +33,8 @@ async fn open_stores_with_name(
if let Some(passphrase) = passphrase {
let state_store = StateStore::open_with_passphrase(name.clone(), passphrase).await?;
let crypto_store = CryptoStore::open_with_passphrase(name, passphrase).await?;
let crypto_store =
CryptoStore::open_with_store_cipher(name, state_store.store_cipher.clone()).await?;
Ok((Box::new(state_store), Box::new(crypto_store)))
} else {
let state_store = StateStore::open_with_name(name.clone()).await?;

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::BTreeSet;
use std::{collections::BTreeSet, sync::Arc};
use anyhow::anyhow;
use futures_util::stream;
@@ -140,7 +140,7 @@ mod KEYS {
pub struct IndexeddbStore {
name: String,
pub(crate) inner: IdbDatabase,
store_cipher: Option<StoreCipher>,
pub(crate) store_cipher: Option<Arc<StoreCipher>>,
}
impl std::fmt::Debug for IndexeddbStore {
@@ -152,7 +152,7 @@ impl std::fmt::Debug for IndexeddbStore {
type Result<A, E = SerializationError> = std::result::Result<A, E>;
impl IndexeddbStore {
async fn open_helper(name: String, store_cipher: Option<StoreCipher>) -> Result<Self> {
async fn open_helper(name: String, store_cipher: Option<Arc<StoreCipher>>) -> Result<Self> {
// Open my_db v1
let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?;
db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> {
@@ -245,7 +245,7 @@ impl IndexeddbStore {
tx.await.into_result()?;
IndexeddbStore::open_helper(name, Some(cipher)).await
IndexeddbStore::open_helper(name, Some(cipher.into())).await
}
pub async fn open_with_name(name: String) -> StoreResult<Self> {

View File

@@ -171,7 +171,7 @@ struct TrackedUser {
#[derive(Clone)]
pub struct SledStore {
account_info: Arc<RwLock<Option<AccountInfo>>>,
store_cipher: Arc<Option<StoreCipher>>,
store_cipher: Option<Arc<StoreCipher>>,
path: Option<PathBuf>,
inner: Db,
@@ -221,13 +221,22 @@ impl SledStore {
.open()
.map_err(|e| CryptoStoreError::Backend(anyhow!(e)))?;
SledStore::open_helper(db, Some(path), passphrase)
let store_cipher = if let Some(passphrase) = passphrase {
Some(Self::get_or_create_store_cipher(passphrase, &db)?.into())
} else {
None
};
SledStore::open_helper(db, Some(path), store_cipher)
}
/// Create a sled based cryptostore using the given sled database.
/// The given passphrase will be used to encrypt private data.
pub fn open_with_database(db: Db, passphrase: Option<&str>) -> Result<Self, OpenStoreError> {
SledStore::open_helper(db, None, passphrase)
pub fn open_with_database(
db: Db,
store_cipher: Option<Arc<StoreCipher>>,
) -> Result<Self, OpenStoreError> {
SledStore::open_helper(db, None, store_cipher)
}
fn get_account_info(&self) -> Option<AccountInfo> {
@@ -235,7 +244,7 @@ impl SledStore {
}
fn serialize_value(&self, event: &impl Serialize) -> Result<Vec<u8>, CryptoStoreError> {
if let Some(key) = &*self.store_cipher {
if let Some(key) = &self.store_cipher {
key.encrypt_value(event).map_err(|e| CryptoStoreError::Backend(anyhow!(e)))
} else {
Ok(serde_json::to_vec(event)?)
@@ -246,7 +255,7 @@ impl SledStore {
&self,
event: &[u8],
) -> Result<T, CryptoStoreError> {
if let Some(key) = &*self.store_cipher {
if let Some(key) = &self.store_cipher {
key.decrypt_value(event).map_err(|e| CryptoStoreError::Backend(anyhow!(e)))
} else {
Ok(serde_json::from_slice(event)?)
@@ -254,7 +263,7 @@ impl SledStore {
}
fn encode_key<T: EncodeKey>(&self, table_name: &str, key: T) -> Vec<u8> {
if let Some(store_cipher) = &*self.store_cipher {
if let Some(store_cipher) = &self.store_cipher {
key.encode_secure(table_name, store_cipher).to_vec()
} else {
key.encode()
@@ -347,7 +356,7 @@ impl SledStore {
fn open_helper(
db: Db,
path: Option<PathBuf>,
passphrase: Option<&str>,
store_cipher: Option<Arc<StoreCipher>>,
) -> Result<Self, OpenStoreError> {
let account = db.open_tree("account")?;
let private_identity = db.open_tree("private_identity")?;
@@ -369,13 +378,6 @@ impl SledStore {
let session_cache = SessionStore::new();
let store_cipher = if let Some(passphrase) = passphrase {
Some(Self::get_or_create_store_cipher(passphrase, &db)?)
} else {
None
}
.into();
let database = Self {
account_info: RwLock::new(None).into(),
path,

View File

@@ -143,7 +143,7 @@ type Result<A, E = SledStoreError> = std::result::Result<A, E>;
pub struct SledStore {
path: Option<PathBuf>,
pub(crate) inner: Db,
store_cipher: Arc<Option<StoreCipher>>,
store_cipher: Option<Arc<StoreCipher>>,
session: Tree,
account_data: Tree,
members: Tree,
@@ -181,7 +181,7 @@ impl SledStore {
fn open_helper(
db: Db,
path: Option<PathBuf>,
store_cipher: Option<StoreCipher>,
store_cipher: Option<Arc<StoreCipher>>,
) -> Result<Self> {
let session = db.open_tree(SESSION)?;
let account_data = db.open_tree(ACCOUNT_DATA)?;
@@ -215,7 +215,7 @@ impl SledStore {
Ok(Self {
path,
inner: db,
store_cipher: store_cipher.into(),
store_cipher,
session,
account_data,
members,
@@ -256,7 +256,7 @@ impl SledStore {
SledStore::open_helper(
db,
None,
Some(StoreCipher::new().expect("can't create store cipher")),
Some(StoreCipher::new().expect("can't create store cipher").into()),
)
.map_err(|e| e.into())
}
@@ -275,7 +275,8 @@ impl SledStore {
let cipher = StoreCipher::new()?;
db.insert("store_cipher".encode(), cipher.export(passphrase)?)?;
cipher
};
}
.into();
SledStore::open_helper(db, Some(path), Some(store_cipher))
}
@@ -295,15 +296,12 @@ impl SledStore {
///
/// The given passphrase will be used to encrypt private data.
#[cfg(feature = "crypto-store")]
pub fn open_crypto_store(
&self,
passphrase: Option<&str>,
) -> Result<CryptoStore, OpenStoreError> {
CryptoStore::open_with_database(self.inner.clone(), passphrase)
pub fn open_crypto_store(&self) -> Result<CryptoStore, OpenStoreError> {
CryptoStore::open_with_database(self.inner.clone(), self.store_cipher.clone())
}
fn serialize_event(&self, event: &impl Serialize) -> Result<Vec<u8>, SledStoreError> {
if let Some(key) = &*self.store_cipher {
if let Some(key) = &self.store_cipher {
Ok(key.encrypt_value(event)?)
} else {
Ok(serde_json::to_vec(event)?)
@@ -314,7 +312,7 @@ impl SledStore {
&self,
event: &[u8],
) -> Result<T, SledStoreError> {
if let Some(key) = &*self.store_cipher {
if let Some(key) = &self.store_cipher {
Ok(key.decrypt_value(event)?)
} else {
Ok(serde_json::from_slice(event)?)
@@ -322,7 +320,7 @@ impl SledStore {
}
fn encode_key<T: EncodeKey>(&self, table_name: &str, key: T) -> Vec<u8> {
if let Some(store_cipher) = &*self.store_cipher {
if let Some(store_cipher) = &self.store_cipher {
key.encode_secure(table_name, store_cipher).to_vec()
} else {
key.encode()

View File

@@ -465,6 +465,15 @@ impl OtherUserIdentity {
let content = self.inner.verification_request_content(methods.clone()).await;
let room = if let Some(room) = self.direct_message_room.read().await.as_ref() {
// Make sure that the user, to be verified, is still in the room
if !room
.active_members()
.await?
.iter()
.any(|member| member.user_id() == self.inner.user_id())
{
room.invite_user_by_id(self.inner.user_id()).await?;
}
room.clone()
} else if let Some(room) =
self.client.create_dm_room(self.inner.user_id().to_owned()).await?

View File

@@ -404,11 +404,12 @@ impl Client {
#[cfg(feature = "encryption")]
fn get_dm_room(&self, user_id: &UserId) -> Option<room::Joined> {
let rooms = self.joined_rooms();
let room_pairs: Vec<_> =
rooms.iter().map(|r| (r.room_id().to_owned(), r.direct_target())).collect();
trace!(rooms = ?room_pairs, "Finding direct room");
let room = rooms.into_iter().find(|r| r.direct_target().as_deref() == Some(user_id));
// Find the room we share with the `user_id` and only with `user_id`
let room = rooms.into_iter().find(|r| {
let targets = r.direct_targets();
targets.len() == 1 && targets.contains(user_id)
});
trace!(?room, "Found room");
room

View File

@@ -47,8 +47,11 @@ pub type HttpResult<T> = std::result::Result<T, HttpError>;
/// representation if the endpoint is from that.
#[derive(Error, Debug)]
pub enum RumaApiError {
/// A client API response error.
#[error(transparent)]
ClientApi(ruma::api::client::Error),
/// Another API response error.
#[error(transparent)]
Other(ruma::api::error::MatrixError),
}

View File

@@ -60,6 +60,6 @@ pub use account::Account;
pub use client::{Client, ClientBuildError, ClientBuilder, LoopCtrl};
#[cfg(feature = "image-proc")]
pub use error::ImageError;
pub use error::{Error, HttpError, HttpResult, Result};
pub use error::{Error, HttpError, HttpResult, Result, RumaApiError};
pub use http_client::HttpSend;
pub use room_member::RoomMember;

View File

@@ -1,4 +1,4 @@
use std::{ops::Deref, sync::Arc};
use std::{collections::BTreeMap, ops::Deref, sync::Arc};
use futures_core::stream::Stream;
use matrix_sdk_base::{
@@ -8,6 +8,7 @@ use matrix_sdk_base::{
use matrix_sdk_common::locks::Mutex;
use ruma::{
api::client::{
config::set_global_account_data,
filter::{LazyLoadOptions, RoomEventFilter},
membership::{get_member_events, join_room_by_id, leave_room},
message::get_message_events::{self, v3::Direction},
@@ -16,10 +17,11 @@ use ruma::{
},
assign,
events::{
direct::DirectEvent,
room::{history_visibility::HistoryVisibility, MediaSource},
tag::{TagInfo, TagName},
AnyRoomAccountDataEvent, AnyStateEvent, AnySyncStateEvent, RedactContent,
RedactedEventContent, RoomAccountDataEvent, RoomAccountDataEventContent,
AnyRoomAccountDataEvent, AnyStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RedactContent, RedactedEventContent, RoomAccountDataEvent, RoomAccountDataEventContent,
RoomAccountDataEventType, StateEventContent, StateEventType, StaticEventContent,
SyncStateEvent,
},
@@ -30,7 +32,7 @@ use ruma::{
use crate::{
media::{MediaFormat, MediaRequest},
room::RoomType,
BaseRoom, Client, HttpError, HttpResult, Result, RoomMember,
BaseRoom, Client, Error, HttpError, HttpResult, Result, RoomMember,
};
/// A struct containing methods that are common for Joined, Invited and Left
@@ -817,6 +819,56 @@ impl Common {
let request = delete_tag::v3::Request::new(&user_id, self.inner.room_id(), tag.as_ref());
self.client.send(request, None).await
}
/// Sets whether this room is a DM.
///
/// When setting this room as DM, it will be marked as DM for all active
/// members of the room. When unsetting this room as DM, it will be
/// unmarked as DM for all users, not just the members.
///
/// # Arguments
/// * `is_direct` - Whether to mark this room as direct.
pub async fn set_is_direct(&self, is_direct: bool) -> Result<()> {
let user_id = self
.client
.user_id()
.await
.ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?;
let mut content = self
.client
.store()
.get_account_data_event(GlobalAccountDataEventType::Direct)
.await?
.map(|e| e.deserialize_as::<DirectEvent>())
.transpose()?
.map(|e| e.content)
.unwrap_or_else(|| ruma::events::direct::DirectEventContent(BTreeMap::new()));
let this_room_id = self.inner.room_id();
if is_direct {
let room_members = self.active_members().await?;
for member in room_members {
let entry = content.entry(member.user_id().to_owned()).or_default();
if !entry.iter().any(|room_id| room_id == this_room_id) {
entry.push(this_room_id.to_owned());
}
}
} else {
for (_, list) in content.iter_mut() {
list.retain(|room_id| *room_id != this_room_id);
}
// Remove user ids that don't have any room marked as DM
content.retain(|_, list| !list.is_empty());
}
let request = set_global_account_data::v3::Request::new(&content, &user_id)?;
self.client.send(request, None).await?;
Ok(())
}
}
/// Options for [`messages`][Common::messages].

View File

@@ -80,11 +80,11 @@ fn open_stores_with_path(
) -> Result<(Box<StateStore>, Box<CryptoStore>), OpenStoreError> {
if let Some(passphrase) = passphrase {
let state_store = StateStore::open_with_passphrase(path, passphrase)?;
let crypto_store = state_store.open_crypto_store(Some(passphrase))?;
let crypto_store = state_store.open_crypto_store()?;
Ok((Box::new(state_store), Box::new(crypto_store)))
} else {
let state_store = StateStore::open_with_path(path)?;
let crypto_store = state_store.open_crypto_store(None)?;
let crypto_store = state_store.open_crypto_store()?;
Ok((Box::new(state_store), Box::new(crypto_store)))
}
}