diff --git a/.typos.toml b/.typos.toml index 0ccb29c56..2579ebb9e 100644 --- a/.typos.toml +++ b/.typos.toml @@ -3,6 +3,7 @@ # Or if we're able to ignore certain lines. Fo = "Fo" BA = "BA" +UE = "UE" [files] # Our json files contain a bunch of base64 encoded ed25519 keys which aren't diff --git a/crates/matrix-sdk-appservice/src/webserver/warp.rs b/crates/matrix-sdk-appservice/src/webserver/warp.rs index 378615846..a7b57989e 100644 --- a/crates/matrix-sdk-appservice/src/webserver/warp.rs +++ b/crates/matrix-sdk-appservice/src/webserver/warp.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{net::ToSocketAddrs, result::Result as StdResult}; +use std::net::ToSocketAddrs; use matrix_sdk::{ bytes::Bytes, @@ -168,7 +168,7 @@ mod handlers { _user_id: String, appservice: AppService, request: http::Request, - ) -> StdResult { + ) -> Result { if let Some(user_exists) = appservice.event_handler.users.lock().await.as_mut() { let request = query_user::IncomingRequest::try_from_http_request(request).map_err(Error::from)?; @@ -185,7 +185,7 @@ mod handlers { _room_id: String, appservice: AppService, request: http::Request, - ) -> StdResult { + ) -> Result { if let Some(room_exists) = appservice.event_handler.rooms.lock().await.as_mut() { let request = query_room::IncomingRequest::try_from_http_request(request).map_err(Error::from)?; @@ -202,7 +202,7 @@ mod handlers { _txn_id: String, appservice: AppService, request: http::Request, - ) -> StdResult { + ) -> Result { let incoming_transaction: ruma::api::appservice::event::push_events::v1::IncomingRequest = ruma::api::IncomingRequest::try_from_http_request(request).map_err(Error::from)?; @@ -224,7 +224,7 @@ struct ErrorMessage { message: String, } -pub async fn handle_rejection(err: Rejection) -> std::result::Result { +pub async fn handle_rejection(err: Rejection) -> Result { if err.find::().is_some() || err.find::().is_some() { let code = http::StatusCode::UNAUTHORIZED; let message = "UNAUTHORIZED"; diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index eef9557c7..7d833d524 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -58,7 +58,7 @@ default-features = false features = ["sync", "fs"] [dev-dependencies] -futures = { version = "0.3.15", default-features = false } +futures = { version = "0.3.15", default-features = false, features = ["executor"] } http = "0.2.4" matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" } diff --git a/crates/matrix-sdk-base/src/session.rs b/crates/matrix-sdk-base/src/session.rs index 0bdcb4546..a5f0620f4 100644 --- a/crates/matrix-sdk-base/src/session.rs +++ b/crates/matrix-sdk-base/src/session.rs @@ -15,11 +15,26 @@ //! User sessions. -use ruma::{DeviceId, UserId}; +use ruma::{DeviceIdBox, UserId}; use serde::{Deserialize, Serialize}; /// A user session, containing an access token and information about the /// associated user account. +/// +/// # Example +/// +/// ``` +/// use matrix_sdk_base::Session; +/// use ruma::{DeviceIdBox, user_id}; +/// +/// let session = Session { +/// access_token: "My-Token".to_owned(), +/// user_id: user_id!("@example:localhost"), +/// device_id: DeviceIdBox::from("MYDEVICEID"), +/// }; +/// +/// assert_eq!(session.device_id.as_str(), "MYDEVICEID"); +/// ``` #[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct Session { /// The access token used for this session. @@ -27,5 +42,15 @@ pub struct Session { /// The user the access token was issued for. pub user_id: UserId, /// The ID of the client device - pub device_id: Box, + pub device_id: DeviceIdBox, +} + +impl From for Session { + fn from(response: ruma::api::client::r0::session::login::Response) -> Self { + Self { + access_token: response.access_token, + user_id: response.user_id, + device_id: response.device_id, + } + } } diff --git a/crates/matrix-sdk-common/Cargo.toml b/crates/matrix-sdk-common/Cargo.toml index f9043bffa..16b2df4a8 100644 --- a/crates/matrix-sdk-common/Cargo.toml +++ b/crates/matrix-sdk-common/Cargo.toml @@ -33,7 +33,7 @@ version = "0.1.9" features = ["now"] [target.'cfg(target_arch = "wasm32")'.dependencies] -futures-util = { version = "0.3.15", default-features = false } -futures-locks = { version = "0.6.0", default-features = false } +async-lock = "2.4.0" +futures-util = { version = "0.3.15", default-features = false, features = ["channel"] } wasm-bindgen-futures = "0.4.24" uuid = { version = "0.8.2", default-features = false, features = ["v4", "wasm-bindgen"] } diff --git a/crates/matrix-sdk-common/src/locks.rs b/crates/matrix-sdk-common/src/locks.rs index f3162309e..fbb89f52d 100644 --- a/crates/matrix-sdk-common/src/locks.rs +++ b/crates/matrix-sdk-common/src/locks.rs @@ -1,8 +1,4 @@ -// could switch to futures-lock completely at some point, blocker: -// https://github.com/asomers/futures-locks/issues/34 -// https://www.reddit.com/r/rust/comments/f4zldz/i_audited_3_different_implementation_of_async/ - #[cfg(target_arch = "wasm32")] -pub use futures_locks::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; +pub use async_lock::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; #[cfg(not(target_arch = "wasm32"))] pub use tokio::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index 39dcb3495..ef7d34cdd 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -16,8 +16,9 @@ features = ["docs"] rustdoc-args = ["--cfg", "feature=\"docs\""] [features] -default = [] +default = ["backups_v1"] qrcode = ["matrix-qrcode"] +backups_v1 = [] sled_cryptostore = ["sled"] docs = ["sled_cryptostore"] indexeddb_cryptostore = ["indexed_db_futures", "wasm-bindgen"] @@ -27,6 +28,7 @@ aes = { version = "0.7.4", features = ["ctr"] } aes-gcm = "0.9.2" atomic = "0.5.0" base64 = "0.13.0" +bs58 = "0.4.0" byteorder = "1.4.3" dashmap = "4.0.2" futures-util = { version = "0.3.15", default-features = false } @@ -34,9 +36,13 @@ getrandom = "0.2.3" hmac = "0.11.0" matrix-qrcode = { version = "0.2.0", path = "../matrix-qrcode", optional = true } matrix-sdk-common = { version = "0.4.0", path = "../matrix-sdk-common" } -olm-rs = { version = "2.0.1", features = ["serde"] } +olm-rs = { version = "2.1", features = ["serde"] } pbkdf2 = { version = "0.9.0", default-features = false } -ruma = { git = "https://github.com/ruma/ruma", rev = "ac6ecc3e5", features = ["client-api-c", "unstable-pre-spec"] } +rand = "0.8.4" +ruma = { git = "https://github.com/ruma/ruma", rev = "ac6ecc3e5", features = [ + "client-api-c", + "unstable-pre-spec", +] } serde = { version = "1.0.126", features = ["derive", "rc"] } serde_json = "1.0.64" sha2 = "0.9.5" @@ -50,17 +56,24 @@ indexed_db_futures = { version = "0.2.0", optional = true } wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"], optional = true } [dev-dependencies] -futures = { version = "0.3.15", default-features = false } +futures = { version = "0.3.15", default-features = false, features = ["executor"] } http = "0.2.4" indoc = "1.0.3" matches = "0.1.8" matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" } [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] +criterion = { version = "0.3.4", features = [ + "async", + "async_tokio", + "html_reports", +] } proptest = "1.0.0" -criterion = { version = "0.3.4", features = ["async", "async_tokio", "html_reports"] } -tokio = { version = "1.7.1", default-features = false, features = ["rt-multi-thread", "macros"] } tempfile = "3.2.0" +tokio = { version = "1.7.1", default-features = false, features = [ + "rt-multi-thread", + "macros", +] } [target.'cfg(target_os = "linux")'.dev-dependencies] criterion = { version = "0.3.4", features = ["async", "html_reports"] } diff --git a/crates/matrix-sdk-crypto/src/backups/keys/backup.rs b/crates/matrix-sdk-crypto/src/backups/keys/backup.rs new file mode 100644 index 000000000..d80e944e9 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/backups/keys/backup.rs @@ -0,0 +1,148 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::BTreeMap, + sync::{Arc, Mutex}, +}; + +use olm_rs::pk::OlmPkEncryption; +use ruma::{ + api::client::r0::backup::{KeyBackupData, KeyBackupDataInit, SessionDataInit}, + DeviceKeyId, UserId, +}; +use zeroize::Zeroizing; + +use super::recovery::DecodeError; +use crate::olm::InboundGroupSession; + +#[derive(Debug)] +struct InnerBackupKey { + key: [u8; MegolmV1BackupKey::KEY_SIZE], + signatures: BTreeMap>, + version: Mutex>, +} + +/// The public part of a backup key. +#[derive(Clone)] +pub struct MegolmV1BackupKey { + inner: Arc, +} + +impl std::fmt::Debug for MegolmV1BackupKey { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("MegolmV1BackupKey") + .field("key", &self.to_base64()) + .field("version", &self.backup_version()) + .finish() + } +} + +impl MegolmV1BackupKey { + const KEY_SIZE: usize = 32; + + pub(super) fn new(public_key: &str, version: Option) -> Self { + let key = Self::from_base64(public_key).expect("Invalid backup key"); + + if let Some(version) = version { + key.set_version(version) + } + + key + } + + /// Get the full name of the backup algorithm this backup key supports. + pub fn backup_algorithm(&self) -> &str { + "m.megolm_backup.v1.curve25519-aes-sha2" + } + + /// Get all the signatures of this `MegolmV1BackupKey`. + pub fn signatures(&self) -> BTreeMap> { + self.inner.signatures.to_owned() + } + + /// Try to create a new `MegolmV1BackupKey` from a base 64 encoded string. + pub fn from_base64(public_key: &str) -> Result { + let mut key = [0u8; Self::KEY_SIZE]; + let decoded_key = crate::utilities::decode(public_key)?; + + if decoded_key.len() != Self::KEY_SIZE { + Err(DecodeError::Length(Self::KEY_SIZE, decoded_key.len())) + } else { + key.copy_from_slice(&decoded_key); + + let inner = + InnerBackupKey { key, signatures: Default::default(), version: Mutex::new(None) }; + + Ok(MegolmV1BackupKey { inner: inner.into() }) + } + } + + /// Convert the [`MegolmV1BackupKey`] to a base 64 encoded string. + pub fn to_base64(&self) -> String { + crate::utilities::encode(&self.inner.key) + } + + /// Get the backup version that this key is used with, if any. + pub fn backup_version(&self) -> Option { + self.inner.version.lock().unwrap().clone() + } + + /// Set the backup version that this `MegolmV1BackupKey` will be used with. + /// + /// The key won't be able to encrypt room keys unless a version has been + /// set. + pub fn set_version(&self, version: String) { + *self.inner.version.lock().unwrap() = Some(version); + } + + pub(crate) async fn encrypt(&self, session: InboundGroupSession) -> KeyBackupData { + let pk = OlmPkEncryption::new(&self.to_base64()); + + // It's ok to truncate here, there's a semantic difference only between + // 0 and 1+ anyways. + let forwarded_count = (session.forwarding_key_chain().len() as u32).into(); + let first_message_index = session.first_known_index().into(); + + // Convert our key to the backup representation. + let key = session.to_backup().await; + + // The key gets zeroized in `BackedUpRoomKey` but we're creating a copy + // here that won't, so let's wrap it up in a `Zeroizing` struct. + let key = + Zeroizing::new(serde_json::to_string(&key).expect("Can't serialize exported room key")); + + let message = pk.encrypt(&key); + + let session_data = SessionDataInit { + ephemeral: message.ephemeral_key, + ciphertext: message.ciphertext, + mac: message.mac, + } + .into(); + + KeyBackupDataInit { + first_message_index, + forwarded_count, + // TODO is this actually used anywhere? seems to be completely + // useless and requires us to get the Device out of the store? + // Also should this be checked at the time of the backup or at the + // time of the room key receival? + is_verified: false, + session_data, + } + .into() + } +} diff --git a/crates/matrix-sdk-crypto/src/backups/keys/mod.rs b/crates/matrix-sdk-crypto/src/backups/keys/mod.rs new file mode 100644 index 000000000..145d0c011 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/backups/keys/mod.rs @@ -0,0 +1,54 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Module for the keys that are used to backup room keys. +//! +//! The backup key is split into two parts: +//! +//! ```text +//! ┌───────────────────────────┐ +//! │ RecoveryKey | BackupKey │ +//! └───────────────────────────┘ +//! ``` +//! +//! 1. RecoveryKey, a private Curve25519 key that is used to decrypt backed up +//! room keys. +//! 2. BackupKey, the public part of the Curve25519 RecoveryKey, this one is +//! used to encrypt room keys that get backed up. +//! +//! The `RecoveryKey` can be derived from a passphrase or randomly generated. +//! To allow other devices access to this key you'll either have to re-enter the +//! passphrase or the key as a base58 encoded string or received from another +//! device using the secret sharing mechanism. +//! +//! In practice deriving a recovery key from a passphrase isn't done, and is +//! **not** supported by the spec. Instead the secret storage key that encrypts +//! secrets and puts them into your global account data gets derived from a +//! passphrase and presented as a base58 encoded string, confusingly also called +//! recovery key. +//! +//! The `RecoveryKey` can also be uploaded to your account data as an +//! `m.megolm.v1` event, which gets encrypted by the SSSS key. +//! +//! The `MegolmV1BackupKey` is used to encrypt individual room keys so they can +//! be uploaded to the homeserver. +//! +//! The `MegolmV1BackupKey` is a public key and is uploaded to the server using +//! the `/room_keys/version` API endpoint. + +mod backup; +mod recovery; + +pub use backup::MegolmV1BackupKey; +pub use recovery::{DecodeError, PickledRecoveryKey, RecoveryKey}; diff --git a/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs b/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs new file mode 100644 index 000000000..88ca3dfbb --- /dev/null +++ b/crates/matrix-sdk-crypto/src/backups/keys/recovery.rs @@ -0,0 +1,354 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + convert::TryFrom, + io::{Cursor, Read}, +}; + +use aes::cipher::generic_array::GenericArray; +use aes_gcm::{ + aead::{Aead, NewAead}, + Aes256Gcm, +}; +use bs58; +use olm_rs::{ + errors::OlmPkDecryptionError, + pk::{OlmPkDecryption, PkMessage}, +}; +use rand::{thread_rng, Error as RandomError, Fill}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use zeroize::{Zeroize, Zeroizing}; + +use super::MegolmV1BackupKey; +use crate::utilities::{decode_url_safe, encode, encode_url_safe}; + +const NONCE_SIZE: usize = 12; + +/// Error type for the decoding of a RecoveryKey. +#[derive(Debug, Error)] +pub enum DecodeError { + /// The decoded recovery key has an invalid prefix. + #[error("The decoded recovery key has an invalid prefix: expected {0:?}, got {1:?}")] + Prefix([u8; 2], [u8; 2]), + /// The parity byte of the recovery key didn't match. + #[error("The parity byte of the recovery key doesn't match: expected {0:?}, got {1:?}")] + Parity(u8, u8), + /// The recovery key has an invalid length. + #[error("The decoded recovery key has a invalid length: expected {0}, got {1}")] + Length(usize, usize), + /// The recovry key isn't valid base58. + #[error(transparent)] + Base58(#[from] bs58::decode::Error), + /// The recovery key isn't valid base64. + #[error(transparent)] + Base64(#[from] base64::DecodeError), + /// The recovery key is too short, we couldn't read enough data. + #[error(transparent)] + Io(#[from] std::io::Error), +} + +#[derive(Debug, Error)] +pub enum UnpicklingError { + #[error(transparent)] + Json(#[from] serde_json::Error), + #[error("Couldn't decrypt the pickle: {0}")] + Decryption(String), + #[error(transparent)] + Decode(#[from] DecodeError), +} + +/// The private part of a backup key. +#[derive(Zeroize)] +pub struct RecoveryKey { + inner: [u8; RecoveryKey::KEY_SIZE], +} + +impl Drop for RecoveryKey { + fn drop(&mut self) { + self.inner.zeroize() + } +} + +impl std::fmt::Debug for RecoveryKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RecoveryKey").finish() + } +} + +/// The pickled version of a recovery key. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PickledRecoveryKey(String); + +impl AsRef for PickledRecoveryKey { + fn as_ref(&self) -> &str { + &self.0 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct InnerPickle { + version: u8, + nonce: String, + ciphertext: String, +} + +impl TryFrom for RecoveryKey { + type Error = DecodeError; + + fn try_from(value: String) -> Result { + Self::from_base58(&value) + } +} + +impl std::fmt::Display for RecoveryKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let string = Zeroizing::new(self.to_base58()); + + let string = Zeroizing::new( + string + .chars() + .collect::>() + .chunks(Self::DISPLAY_CHUNK_SIZE) + .map(|c| c.iter().collect::()) + .collect::>() + .join(" "), + ); + + write!(f, "{}", string.as_str()) + } +} + +impl RecoveryKey { + const KEY_SIZE: usize = 32; + const PREFIX: [u8; 2] = [0x8b, 0x01]; + const PREFIX_PARITY: u8 = Self::PREFIX[0] ^ Self::PREFIX[1]; + const DISPLAY_CHUNK_SIZE: usize = 4; + + fn parity_byte(bytes: &[u8]) -> u8 { + bytes.iter().fold(Self::PREFIX_PARITY, |acc, x| acc ^ x) + } + + /// Create a new random recovery key. + pub fn new() -> Result { + let mut rng = thread_rng(); + + let mut key = [0u8; Self::KEY_SIZE]; + key.try_fill(&mut rng)?; + + Ok(Self { inner: key }) + } + + /// Create a new recovery key from the given byte array. + /// + /// **Warning**: You need to make sure that the byte array contains correct + /// random data, either by using a random number generator or by using an + /// exported version of a previously created [`RecoveryKey`]. + pub fn from_bytes(key: [u8; Self::KEY_SIZE]) -> Self { + Self { inner: key } + } + + /// Try to create a [`RecoveryKey`] from a base64 export of a `RecoveryKey`. + pub fn from_base64(key: &str) -> Result { + let decoded = Zeroizing::new(crate::utilities::decode(key)?); + + if decoded.len() != Self::KEY_SIZE { + Err(DecodeError::Length(decoded.len(), Self::KEY_SIZE)) + } else { + let mut key = [0u8; Self::KEY_SIZE]; + key.copy_from_slice(&decoded); + + Ok(Self::from_bytes(key)) + } + } + + /// Export the `RecoveryKey` as a base64 encoded string. + pub fn to_base64(&self) -> String { + encode(self.inner) + } + + /// Try to create a [`RecoveryKey`] from a base58 export of a `RecoveryKey`. + pub fn from_base58(value: &str) -> Result { + // Remove any whitespace we might have + let value: String = value.chars().filter(|c| !c.is_whitespace()).collect(); + + let decoded = bs58::decode(value).with_alphabet(bs58::Alphabet::BITCOIN).into_vec()?; + let mut decoded = Cursor::new(decoded); + + let mut prefix = [0u8; 2]; + let mut key = [0u8; Self::KEY_SIZE]; + let mut expected_parity = [0u8; 1]; + + decoded.read_exact(&mut prefix)?; + decoded.read_exact(&mut key)?; + decoded.read_exact(&mut expected_parity)?; + + let expected_parity = expected_parity[0]; + let parity = Self::parity_byte(key.as_ref()); + + let _ = Zeroizing::new(decoded.into_inner()); + + if prefix != Self::PREFIX { + Err(DecodeError::Prefix(Self::PREFIX, prefix)) + } else if expected_parity != parity { + Err(DecodeError::Parity(expected_parity, parity)) + } else { + Ok(Self::from_bytes(key)) + } + } + + /// Export the `RecoveryKey` as a base58 encoded string. + pub fn to_base58(&self) -> String { + let bytes = Zeroizing::new( + [ + Self::PREFIX.as_ref(), + self.inner.as_ref(), + [Self::parity_byte(self.inner.as_ref())].as_ref(), + ] + .concat(), + ); + + bs58::encode(bytes.as_slice()).with_alphabet(bs58::Alphabet::BITCOIN).into_string() + } + + fn get_pk_decrytpion(&self) -> OlmPkDecryption { + OlmPkDecryption::from_bytes(self.inner.as_ref()) + .expect("Can't generate a libom PkDecryption object from our private key") + } + + /// Extract the megolm.v1 public key from this `RecoveryKey`. + pub fn megolm_v1_public_key(&self) -> MegolmV1BackupKey { + let pk = self.get_pk_decrytpion(); + let public_key = MegolmV1BackupKey::new(pk.public_key(), None); + public_key + } + + /// Export this [`RecoveryKey`] as an encrypted pickle that can be safely + /// stored. + pub fn pickle(&self, pickle_key: &[u8]) -> PickledRecoveryKey { + let key = GenericArray::from_slice(pickle_key); + let cipher = Aes256Gcm::new(key); + + let mut nonce = vec![0u8; NONCE_SIZE]; + let mut rng = thread_rng(); + + nonce.try_fill(&mut rng).expect("Can't generate random nocne to pickle the recovery key"); + let nonce = GenericArray::from_slice(nonce.as_slice()); + + let ciphertext = + cipher.encrypt(nonce, self.inner.as_ref()).expect("Can't encrypt recovery key"); + + let ciphertext = encode_url_safe(ciphertext); + + let pickle = + InnerPickle { version: 1, nonce: encode_url_safe(nonce.as_slice()), ciphertext }; + + PickledRecoveryKey(serde_json::to_string(&pickle).expect("Can't encode pickled signing")) + } + + /// Try to import a `RecoveryKey` from a previously exported pickle. + pub fn from_pickle( + pickle: PickledRecoveryKey, + pickle_key: &[u8], + ) -> Result { + let pickled: InnerPickle = serde_json::from_str(pickle.as_ref())?; + + let key = GenericArray::from_slice(pickle_key); + let cipher = Aes256Gcm::new(key); + + let nonce = decode_url_safe(pickled.nonce).map_err(DecodeError::from)?; + let nonce = GenericArray::from_slice(&nonce); + let ciphertext = &decode_url_safe(pickled.ciphertext).map_err(DecodeError::from)?; + + let decrypted = cipher + .decrypt(nonce, ciphertext.as_slice()) + .map_err(|e| UnpicklingError::Decryption(e.to_string()))?; + + if decrypted.len() != Self::KEY_SIZE { + Err(DecodeError::Length(decrypted.len(), Self::KEY_SIZE).into()) + } else { + let mut key = [0u8; Self::KEY_SIZE]; + key.copy_from_slice(&decrypted); + + Ok(Self { inner: key }) + } + } + + /// Try to decrypt the given ciphertext using this `RecoveryKey`. + /// + /// This will use the [`m.megolm_backup.v1.curve25519-aes-sha2`] algorithm + /// to decrypt the given ciphertext. + /// + /// [`m.megolm_backup.v1.curve25519-aes-sha2`]: + /// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + pub fn decrypt_v1( + &self, + mac: String, + ephemeral_key: String, + ciphertext: String, + ) -> Result { + let message = PkMessage::new(mac, ephemeral_key, ciphertext); + let pk = self.get_pk_decrytpion(); + + pk.decrypt(message) + } +} + +#[cfg(test)] +mod test { + use super::{DecodeError, RecoveryKey}; + + const TEST_KEY: [u8; 32] = [ + 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, + 0x45, 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, + 0x2C, 0x2A, + ]; + + #[test] + fn base64_decoding() -> Result<(), DecodeError> { + let key = RecoveryKey::new().expect("Can't create a new recovery key"); + + let base64 = key.to_base64(); + let decoded_key = RecoveryKey::from_base64(&base64)?; + assert_eq!(key.inner, decoded_key.inner, "The decode key does't match the original"); + + RecoveryKey::from_base64("i").expect_err("The recovery key is too short"); + + Ok(()) + } + + #[test] + fn base58_decoding() -> Result<(), DecodeError> { + let key = RecoveryKey::new().expect("Can't create a new recovery key"); + + let base64 = key.to_base58(); + let decoded_key = RecoveryKey::from_base58(&base64)?; + assert_eq!(key.inner, decoded_key.inner, "The decode key does't match the original"); + + let test_key = + RecoveryKey::from_base58("EsTcLW2KPGiFwKEA3As5g5c4BXwkqeeJZJV8Q9fugUMNUE4d")?; + assert_eq!(test_key.inner, TEST_KEY, "The decoded recovery key doesn't match the test key"); + + let test_key = RecoveryKey::from_base58( + "EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4d", + )?; + assert_eq!(test_key.inner, TEST_KEY, "The decoded recovery key doesn't match the test key"); + + RecoveryKey::from_base58("EsTc LW2K PGiF wKEA 3As5 g5c4 BXwk qeeJ ZJV8 Q9fu gUMN UE4e") + .expect_err("Can't create a recovery key if the parity byte is invalid"); + + Ok(()) + } +} diff --git a/crates/matrix-sdk-crypto/src/backups/mod.rs b/crates/matrix-sdk-crypto/src/backups/mod.rs new file mode 100644 index 000000000..bc0fe9b94 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/backups/mod.rs @@ -0,0 +1,519 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Server-side backup support for room keys +//! +//! This module largely implements support for server-side backups using the +//! `m.megolm_backup.v1.curve25519-aes-sha2` backup algorithm. +//! +//! Due to various flaws in this backup algorithm it is **not** recommended to +//! use this module or any of its functionality. The module is only provided for +//! backwards compatibility. +//! +//! [spec]: https://spec.matrix.org/unstable/client-server-api/#server-side-key-backups + +use std::{ + collections::{BTreeMap, BTreeSet}, + sync::Arc, +}; + +use matrix_sdk_common::{locks::RwLock, uuid::Uuid}; +use ruma::{ + api::client::r0::backup::RoomKeyBackup, DeviceKeyAlgorithm, DeviceKeyId, RoomId, UserId, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tracing::{debug, info, instrument, trace, warn}; + +use crate::{ + olm::{Account, InboundGroupSession}, + store::{BackupKeys, Changes, RoomKeyCounts, Store}, + CryptoStoreError, KeysBackupRequest, OutgoingRequest, +}; + +mod keys; + +pub use keys::{DecodeError, MegolmV1BackupKey, PickledRecoveryKey, RecoveryKey}; +pub use olm_rs::errors::OlmPkDecryptionError; + +/// A state machine that handles backing up room keys. +/// +/// The state machine can be activated using the +/// [`BackupMachine::enable_backup_v1`] method. After the state machine has been +/// enabled a request that will upload encrypted room keys can be generated +/// using the [`BackupMachine::backup`] method. +#[derive(Debug, Clone)] +pub struct BackupMachine { + account: Account, + store: Store, + backup_key: Arc>>, + pending_backup: Arc>>, +} + +#[derive(Debug, Clone)] +struct PendingBackup { + request_id: Uuid, + request: KeysBackupRequest, + sessions: BTreeMap>>, +} + +impl PendingBackup { + fn session_was_part_of_the_backup(&self, session: &InboundGroupSession) -> bool { + self.sessions + .get(session.room_id()) + .and_then(|r| r.get(session.sender_key()).map(|s| s.contains(session.session_id()))) + .unwrap_or(false) + } +} + +impl From for OutgoingRequest { + fn from(b: PendingBackup) -> Self { + OutgoingRequest { request_id: b.request_id, request: Arc::new(b.request.into()) } + } +} + +impl BackupMachine { + const BACKUP_BATCH_SIZE: usize = 100; + + pub(crate) fn new( + account: Account, + store: Store, + backup_key: Option, + ) -> Self { + Self { + account, + store, + backup_key: RwLock::new(backup_key).into(), + pending_backup: RwLock::new(None).into(), + } + } + + /// Are we able to back up room keys to the server? + pub async fn enabled(&self) -> bool { + self.backup_key.read().await.as_ref().map(|b| b.backup_version().is_some()).unwrap_or(false) + } + + /// Verify some backup auth data that we downloaded from the server. + /// + /// The auth data should be fetched from the server using the + /// [`/room_keys/version`] endpoint. + /// + /// Ruma models this using the [`BackupAlgorithm`] struct, but care needs to + /// be taken if that struct is used. Some clients might use unspecced fields + /// and the given Ruma struct will lose the unspecced fields after a + /// serialization cycle. + /// + /// [`BackupAlgorithm`]: ruma::api::client::r0::backup::BackupAlgorithm + /// [`/room_keys/version`]: https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3room_keysversion + pub async fn verify_backup( + &self, + mut serialized_auth_data: Value, + ) -> Result { + #[derive(Debug, Serialize, Deserialize)] + struct AuthData { + public_key: String, + #[serde(default)] + signatures: BTreeMap>, + #[serde(flatten)] + extra: BTreeMap, + } + + let auth_data: AuthData = serde_json::from_value(serialized_auth_data.clone())?; + + trace!(?auth_data, "Verifying backup auth data"); + + Ok(if let Some(signatures) = auth_data.signatures.get(self.store.user_id()) { + for device_key_id in signatures.keys() { + if device_key_id.algorithm() == DeviceKeyAlgorithm::Ed25519 { + if device_key_id.device_id() == self.account.device_id() { + let result = self.account.is_signed(&mut serialized_auth_data); + + trace!(?result, "Checking auth data signature of our own device"); + + if result.is_ok() { + return Ok(true); + } + } else { + let device = self + .store + .get_device(self.store.user_id(), device_key_id.device_id()) + .await?; + + trace!( + device_id = device_key_id.device_id().as_str(), + "Checking backup auth data for device" + ); + + if let Some(device) = device { + if device.verified() + && device.is_signed_by_device(&mut serialized_auth_data).is_ok() + { + return Ok(true); + } + } else { + trace!( + device_id = device_key_id.device_id().as_str(), + "Device not found, can't check signature" + ); + } + } + } + } + + false + } else { + false + }) + } + + /// Activate the given backup key to be used to encrypt and backup room + /// keys. + /// + /// This will use the [`m.megolm_backup.v1.curve25519-aes-sha2`] algorithm + /// to encrypt the room keys. + /// + /// [`m.megolm_backup.v1.curve25519-aes-sha2`]: + /// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + pub async fn enable_backup_v1(&self, key: MegolmV1BackupKey) -> Result<(), CryptoStoreError> { + if key.backup_version().is_some() { + *self.backup_key.write().await = Some(key.clone()); + info!(backup_key =? key, "Activated a backup"); + } else { + warn!(backup_key =? key, "Tried to activate a backup without having the backup key uploaded"); + } + + Ok(()) + } + + /// Get the number of backed up room keys and the total number of room keys. + pub async fn room_key_counts(&self) -> Result { + self.store.inbound_group_session_counts().await + } + + /// Disable and reset our backup state. + /// + /// This will remove any pending backup request, remove the backup key and + /// reset the backup state of each room key we have. + #[instrument(skip(self))] + pub async fn disable_backup(&self) -> Result<(), CryptoStoreError> { + debug!("Disabling key backup and resetting backup state for room keys"); + + self.backup_key.write().await.take(); + self.pending_backup.write().await.take(); + + self.store.reset_backup_state().await?; + + debug!("Done disabling backup"); + + Ok(()) + } + + /// Store the recovery key in the cryptostore. + /// + /// This is useful if the client wants to support gossiping of the backup + /// key. + pub async fn save_recovery_key( + &self, + recovery_key: Option, + version: Option, + ) -> Result<(), CryptoStoreError> { + let changes = Changes { recovery_key, backup_version: version, ..Default::default() }; + self.store.save_changes(changes).await + } + + /// Get the backup keys we have saved in our crypto store. + pub async fn get_backup_keys(&self) -> Result { + self.store.load_backup_keys().await + } + + /// Encrypt a batch of room keys and return a request that needs to be sent + /// out to backup the room keys. + pub async fn backup(&self) -> Result, CryptoStoreError> { + let mut request = self.pending_backup.write().await; + + if let Some(request) = &*request { + trace!("Backing up, returning an existing request"); + + Ok(Some(request.clone().into())) + } else { + trace!("Backing up, creating a new request"); + + let new_request = self.backup_helper().await?; + *request = new_request.clone(); + + Ok(new_request.map(|r| r.into())) + } + } + + pub(crate) async fn mark_request_as_sent( + &self, + request_id: Uuid, + ) -> Result<(), CryptoStoreError> { + let mut request = self.pending_backup.write().await; + + if let Some(r) = &*request { + if r.request_id == request_id { + let sessions: Vec<_> = self + .store + .get_inbound_group_sessions() + .await? + .into_iter() + .filter(|s| r.session_was_part_of_the_backup(s)) + .collect(); + + for session in &sessions { + session.mark_as_backed_up() + } + + trace!(request_id =? r.request_id, keys =? r.sessions, "Marking room keys as backed up"); + + let changes = Changes { inbound_group_sessions: sessions, ..Default::default() }; + self.store.save_changes(changes).await?; + + let counts = self.store.inbound_group_session_counts().await?; + + trace!( + room_key_counts =? counts, + request_id =? r.request_id, keys =? r.sessions, "Marked room keys as backed up" + ); + + *request = None; + } else { + warn!( + expected = r.request_id.to_string().as_str(), + got = request_id.to_string().as_str(), + "Tried to mark a pending backup as sent but the request id didn't match" + ) + } + } else { + warn!( + request_id = request_id.to_string().as_str(), + "Tried to mark a pending backup as sent but there isn't a backup pending" + ); + } + + Ok(()) + } + + async fn backup_helper(&self) -> Result, CryptoStoreError> { + if let Some(backup_key) = &*self.backup_key.read().await { + if let Some(version) = backup_key.backup_version() { + let sessions = + self.store.inbound_group_sessions_for_backup(Self::BACKUP_BATCH_SIZE).await?; + + if !sessions.is_empty() { + let key_count = sessions.len(); + let (backup, session_record) = Self::backup_keys(sessions, backup_key).await; + + info!( + key_count = key_count, + keys =? session_record, + backup_key =? backup_key, + "Successfully created a room keys backup request" + ); + + let request = PendingBackup { + request_id: Uuid::new_v4(), + request: KeysBackupRequest { version, rooms: backup }, + sessions: session_record, + }; + + Ok(Some(request)) + } else { + trace!(?backup_key, "No room keys need to be backed up"); + Ok(None) + } + } else { + warn!("Trying to backup room keys but the backup key wasn't uploaded"); + Ok(None) + } + } else { + warn!("Trying to backup room keys but no backup key was found"); + Ok(None) + } + } + + /// Backup all the non-backed up room keys we know about + async fn backup_keys( + sessions: Vec, + backup_key: &MegolmV1BackupKey, + ) -> (BTreeMap, BTreeMap>>) + { + let mut backup: BTreeMap = BTreeMap::new(); + let mut session_record: BTreeMap>> = + BTreeMap::new(); + + for session in sessions.into_iter() { + let room_id = session.room_id().to_owned(); + let session_id = session.session_id().to_owned(); + let sender_key = session.sender_key().to_owned(); + let session = backup_key.encrypt(session).await; + + session_record + .entry(room_id.clone()) + .or_default() + .entry(sender_key) + .or_default() + .insert(session_id.clone()); + + backup + .entry(room_id) + .or_insert_with(|| RoomKeyBackup::new(BTreeMap::new())) + .sessions + .insert(session_id, session); + } + + (backup, session_record) + } +} + +#[cfg(test)] +mod test { + use matrix_sdk_test::async_test; + use ruma::{room_id, user_id, DeviceIdBox, RoomId, UserId}; + + use super::RecoveryKey; + use crate::{OlmError, OlmMachine}; + + fn alice_id() -> UserId { + user_id!("@alice:example.org") + } + + fn alice_device_id() -> DeviceIdBox { + "JLAFKJWSCS".into() + } + + fn room_id() -> RoomId { + room_id!("!test:localhost") + } + + fn room_id2() -> RoomId { + room_id!("!test2:localhost") + } + + async fn backup_flow(machine: OlmMachine) -> Result<(), OlmError> { + let backup_machine = machine.backup_machine(); + let counts = backup_machine.store.inbound_group_session_counts().await?; + + assert_eq!(counts.total, 0, "Initially no keys exist"); + assert_eq!(counts.backed_up, 0, "Initially no backed up keys exist"); + + machine.create_outbound_group_session_with_defaults(&room_id()).await?; + machine.create_outbound_group_session_with_defaults(&room_id2()).await?; + + let counts = backup_machine.store.inbound_group_session_counts().await?; + assert_eq!(counts.total, 2, "Two room keys need to exist in the store"); + assert_eq!(counts.backed_up, 0, "No room keys have been backed up yet"); + + let recovery_key = RecoveryKey::new().expect("Can't create new recovery key"); + let backup_key = recovery_key.megolm_v1_public_key(); + backup_key.set_version("1".to_owned()); + + backup_machine.enable_backup_v1(backup_key).await?; + + let request = + backup_machine.backup().await?.expect("Created a backup request successfully"); + assert_eq!( + Some(request.request_id), + backup_machine.backup().await?.map(|r| r.request_id), + "Calling backup again without uploading creates the same backup request" + ); + + backup_machine.mark_request_as_sent(request.request_id).await?; + + let counts = backup_machine.store.inbound_group_session_counts().await?; + assert_eq!(counts.total, 2); + assert_eq!(counts.backed_up, 2, "All room keys have been backed up"); + + assert!( + backup_machine.backup().await?.is_none(), + "No room keys need to be backed up, no request needs to be created" + ); + + backup_machine.disable_backup().await?; + + let counts = backup_machine.store.inbound_group_session_counts().await?; + assert_eq!(counts.total, 2); + assert_eq!( + counts.backed_up, 0, + "Disabling the backup resets the backup flag on the room keys" + ); + + Ok(()) + } + + #[async_test] + async fn memory_store_backups() -> Result<(), OlmError> { + let machine = OlmMachine::new(&alice_id(), &alice_device_id()); + + backup_flow(machine).await + } + + #[async_test] + #[cfg(feature = "sled_cryptostore")] + async fn default_store_backups() -> Result<(), OlmError> { + use tempfile::tempdir; + + let tmpdir = tempdir().expect("Can't create a temporary dir"); + let machine = OlmMachine::new_with_default_store( + &alice_id(), + &alice_device_id(), + tmpdir.as_ref(), + None, + ) + .await?; + + backup_flow(machine).await + } + + #[async_test] + #[cfg(feature = "sled_cryptostore")] + async fn recovery_key_storing() -> Result<(), OlmError> { + use tempfile::tempdir; + + let tmpdir = tempdir().expect("Can't create a temporary dir"); + let machine = OlmMachine::new_with_default_store( + &alice_id(), + &alice_device_id(), + tmpdir.as_ref(), + Some("test"), + ) + .await?; + let backup_machine = machine.backup_machine(); + + let recovery_key = RecoveryKey::new().expect("Can't create new recovery key"); + let encoded_key = recovery_key.to_base64(); + + backup_machine.save_recovery_key(Some(recovery_key), Some("1".to_owned())).await?; + + let loded_backup = backup_machine.get_backup_keys().await?; + + assert_eq!( + encoded_key, + loded_backup + .recovery_key + .expect("The recovery key wasn't loaded from the store") + .to_base64(), + "The loaded key matches to the one we stored" + ); + + assert_eq!( + Some("1"), + loded_backup.backup_version.as_deref(), + "The loaded version matches to the one we stored" + ); + + Ok(()) + } +} diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index db30dc086..7bbd9c5fb 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -770,6 +770,8 @@ impl GossipMachine { return Ok(None); }; + let secret = std::mem::take(&mut event.content.secret); + if let Some(request) = self.store.get_outgoing_secret_requests(request_id).await? { match &request.info { SecretInfo::KeyRequest(_) => { @@ -791,30 +793,32 @@ impl GossipMachine { self.store.get_device_from_curve_key(&event.sender, sender_key).await? { if device.verified() { - match self - .store - .import_secret( - secret_name, - std::mem::take(&mut event.content.secret), - ) - .await - { - Ok(_) => self.mark_as_done(request).await?, - Err(e) => { - // If this is a store error propagate it up - // the call stack. - if let SecretImportError::Store(e) = e { - return Err(e); - } else { - // Otherwise warn that there was - // something wrong with the secret. - warn!( - secret_name = secret_name.as_ref(), - error =? e, - "Error while importing a secret" - ) + if secret_name != &SecretName::RecoveryKey { + match self.store.import_secret(secret_name, secret).await { + Ok(_) => self.mark_as_done(request).await?, + Err(e) => { + // If this is a store error propagate it up + // the call stack. + if let SecretImportError::Store(e) = e { + return Err(e); + } else { + // Otherwise warn that there was + // something wrong with the secret. + warn!( + secret_name = secret_name.as_ref(), + error =? e, + "Error while importing a secret" + ) + } } } + } else { + // Skip importing the recovery key here since + // we'll want to check if the public key matches + // to the latest version on the server. We + // instead leave the key in the event and let + // the user import it later. + event.content.secret = secret; } } else { warn!( diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index c7e0759dd..1f69422a0 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -539,7 +539,7 @@ impl ReadOnlyDevice { Ok(()) } - fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> { + pub(crate) fn is_signed_by_device(&self, json: &mut Value) -> Result<(), SignatureError> { let signing_key = self.get_key(DeviceKeyAlgorithm::Ed25519).ok_or(SignatureError::MissingSigningKey)?; @@ -623,9 +623,6 @@ impl PartialEq for ReadOnlyDevice { #[cfg(test)] pub(crate) mod test { - #[cfg(target_arch = "wasm32")] - use wasm_bindgen_test::wasm_bindgen_test; - use matrix_sdk_test::async_test; use std::convert::TryFrom; use ruma::{encryption::DeviceKeys, user_id, DeviceKeyAlgorithm}; diff --git a/crates/matrix-sdk-crypto/src/lib.rs b/crates/matrix-sdk-crypto/src/lib.rs index e387e3fac..9e65a77e2 100644 --- a/crates/matrix-sdk-crypto/src/lib.rs +++ b/crates/matrix-sdk-crypto/src/lib.rs @@ -30,6 +30,9 @@ compile_error!("indexeddb_cryptostore only works for wasm32 target"); +#[cfg(feature = "backups_v1")] +#[cfg_attr(feature = "docs", doc(cfg(backups_v1)))] +pub mod backups; mod error; mod file_encryption; mod gossiping; @@ -72,7 +75,7 @@ pub use matrix_qrcode; pub(crate) use olm::ReadOnlyAccount; pub use olm::{CrossSigningStatus, EncryptionSettings}; pub use requests::{ - IncomingResponse, KeysQueryRequest, OutgoingRequest, OutgoingRequests, + IncomingResponse, KeysBackupRequest, KeysQueryRequest, OutgoingRequest, OutgoingRequests, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest, UploadSigningKeysRequest, }; pub use store::{CrossSigningKeyExport, CryptoStoreError, SecretImportError}; diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index 9d2ad8198..d532fb6cb 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -46,11 +46,14 @@ use ruma::{ secret::request::SecretName, AnyMessageEventContent, AnyRoomEvent, AnyToDeviceEvent, EventContent, }, - DeviceId, DeviceIdBox, DeviceKeyAlgorithm, EventEncryptionAlgorithm, RoomId, UInt, UserId, + DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UInt, + UserId, }; use serde_json::Value; use tracing::{debug, error, info, trace, warn}; +#[cfg(feature = "backups_v1")] +use crate::backups::BackupMachine; #[cfg(feature = "sled_cryptostore")] use crate::store::sled::SledStore; use crate::{ @@ -104,6 +107,9 @@ pub struct OlmMachine { /// State machine handling public user identities and devices, keeping track /// of when a key query needs to be done and handling one. identity_manager: IdentityManager, + /// A state machine that handles creating room key backups. + #[cfg(feature = "backups_v1")] + backup_machine: BackupMachine, } #[cfg(not(tarpaulin_include))] @@ -180,6 +186,9 @@ impl OlmMachine { let identity_manager = IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); + #[cfg(feature = "backups_v1")] + let backup_machine = BackupMachine::new(account.clone(), store.clone(), None); + OlmMachine { user_id, device_id, @@ -191,6 +200,8 @@ impl OlmMachine { verification_machine, key_request_machine, identity_manager, + #[cfg(feature = "backups_v1")] + backup_machine, } } @@ -371,6 +382,10 @@ impl OlmMachine { IncomingResponse::RoomMessage(_) => { self.verification_machine.mark_request_as_sent(request_id); } + IncomingResponse::KeysBackup(_) => { + #[cfg(feature = "backups_v1")] + self.backup_machine.mark_request_as_sent(*request_id).await?; + } }; Ok(()) @@ -651,18 +666,18 @@ impl OlmMachine { Ok(()) } - // #[cfg(test)] - // pub(crate) async fn create_inbound_session( - // &self, - // room_id: &RoomId, - // ) -> OlmResult { - // let (_, session) = self - // .group_session_manager - // .create_outbound_group_session(room_id, EncryptionSettings::default()) - // .await?; + #[cfg(test)] + pub(crate) async fn create_inbound_session( + &self, + room_id: &RoomId, + ) -> OlmResult { + let (_, session) = self + .group_session_manager + .create_outbound_group_session(room_id, EncryptionSettings::default()) + .await?; - // Ok(session) - // } + Ok(session) + } /// Encrypt a room message for the given room. /// @@ -1453,6 +1468,62 @@ impl OlmMachine { ) -> Result { self.store.import_cross_signing_keys(export).await } + + async fn sign_account( + &self, + message: &str, + signatures: &mut BTreeMap>, + ) { + let device_key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()); + let signature = self.account.sign(message).await; + + signatures.entry(self.user_id().to_owned()).or_default().insert(device_key_id, signature); + } + + async fn sign_master( + &self, + message: &str, + signatures: &mut BTreeMap>, + ) -> Result<(), crate::SignatureError> { + let identity = &*self.user_identity.lock().await; + + let master_key: DeviceIdBox = identity + .master_public_key() + .await + .and_then(|m| m.get_first_key().map(|k| k.to_owned())) + .ok_or(crate::SignatureError::MissingSigningKey)? + .into(); + + let device_key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &master_key); + let signature = identity.sign(message).await?; + + signatures.entry(self.user_id().to_owned()).or_default().insert(device_key_id, signature); + + Ok(()) + } + + /// Sign the given message using our device key and if available cross + /// signing master key. + pub async fn sign(&self, message: &str) -> BTreeMap> { + let mut signatures: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new(); + + self.sign_account(message, &mut signatures).await; + + if let Err(e) = self.sign_master(message, &mut signatures).await { + warn!(error =? e, "Couldn't sign the message using the cross signing master key") + } + + signatures + } + + /// Get a reference to the backup related state machine. + /// + /// This state machine can be used to incrementally backup all room keys to + /// the server. + #[cfg(feature = "backups_v1")] + pub fn backup_machine(&self) -> &BackupMachine { + &self.backup_machine + } } #[cfg(test)] diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index 21ed2c606..83a9c6d5e 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -686,6 +686,19 @@ impl ReadOnlyAccount { self.inner.lock().await.sign(string) } + #[cfg(feature = "backups_v1")] + pub(crate) fn is_signed(&self, json: &mut Value) -> Result<(), SignatureError> { + let signing_key = self.identity_keys.ed25519(); + let utility = crate::olm::Utility::new(); + + utility.verify_json( + &self.user_id, + &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()), + signing_key, + json, + ) + } + /// Store the account as a base64 encoded string. /// /// # Arguments @@ -806,7 +819,8 @@ impl ReadOnlyAccount { ) -> Result { let public_key = master_key.get_first_key().ok_or(SignatureError::MissingSigningKey)?.to_string(); - let mut cross_signing_key = master_key.into(); + let mut cross_signing_key: CrossSigningKey = master_key.into(); + cross_signing_key.signatures.clear(); self.sign_cross_signing_key(&mut cross_signing_key).await?; let mut signed_keys = BTreeMap::new(); diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs index 4cc71fd21..8e1927c21 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs @@ -12,7 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, convert::TryFrom, fmt, mem, sync::Arc}; +use std::{ + collections::BTreeMap, + convert::TryFrom, + fmt, mem, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, +}; use matrix_sdk_common::locks::Mutex; pub use olm_rs::{ @@ -39,7 +47,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use zeroize::Zeroizing; -use super::{ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey}; +use super::{BackedUpRoomKey, ExportedGroupSessionKey, ExportedRoomKey, GroupSessionKey}; use crate::error::{EventError, MegolmResult}; // TODO add creation times to the inbound group sessions so we can export @@ -61,7 +69,7 @@ pub struct InboundGroupSession { pub(crate) room_id: Arc, forwarding_chains: Arc>, imported: bool, - backed_up: bool, + backed_up: Arc, } impl InboundGroupSession { @@ -105,7 +113,7 @@ impl InboundGroupSession { room_id: room_id.clone().into(), forwarding_chains: Vec::new().into(), imported: false, - backed_up: false, + backed_up: AtomicBool::new(false).into(), }) } @@ -123,6 +131,25 @@ impl InboundGroupSession { Self::try_from(exported_session.into()) } + #[allow(dead_code)] + fn from_backup( + room_id: &RoomId, + backup: BackedUpRoomKey, + ) -> Result { + let session = OlmInboundGroupSession::new(&backup.session_key.0)?; + let session_id = session.session_id(); + + Self::from_export(ExportedRoomKey { + algorithm: backup.algorithm, + room_id: room_id.to_owned(), + sender_key: backup.sender_key, + session_id, + session_key: backup.session_key, + sender_claimed_keys: backup.sender_claimed_keys, + forwarding_curve25519_key_chain: backup.forwarding_curve25519_key_chain, + }) + } + /// Create a new inbound group session from a forwarded room key content. /// /// # Arguments @@ -157,7 +184,7 @@ impl InboundGroupSession { room_id: content.room_id.clone().into(), forwarding_chains: forwarding_chains.into(), imported: true, - backed_up: false, + backed_up: AtomicBool::new(false).into(), }) } @@ -177,7 +204,7 @@ impl InboundGroupSession { room_id: (&*self.room_id).clone(), forwarding_chains: self.forwarding_key_chain().to_vec(), imported: self.imported, - backed_up: self.backed_up, + backed_up: self.backed_up(), history_visibility: self.history_visibility.as_ref().clone(), } } @@ -195,6 +222,21 @@ impl InboundGroupSession { &self.sender_key } + /// Has the session been backed up to the server. + pub fn backed_up(&self) -> bool { + self.backed_up.load(SeqCst) + } + + /// Reset the backup state of the inbound group session. + pub(crate) fn reset_backup_state(&self) { + self.backed_up.store(false, SeqCst) + } + + #[cfg(feature = "backups_v1")] + pub(crate) fn mark_as_backed_up(&self) { + self.backed_up.store(true, SeqCst) + } + /// Get the map of signing keys this session was received from. pub fn signing_keys(&self) -> &BTreeMap { &self.signing_keys @@ -256,7 +298,7 @@ impl InboundGroupSession { signing_keys: pickle.signing_key.into(), room_id: pickle.room_id.into(), forwarding_chains: pickle.forwarding_chains.into(), - backed_up: pickle.backed_up, + backed_up: AtomicBool::from(pickle.backed_up).into(), imported: pickle.imported, }) } @@ -291,6 +333,11 @@ impl InboundGroupSession { self.inner.lock().await.decrypt(message) } + #[cfg(feature = "backups_v1")] + pub(crate) async fn to_backup(&self) -> BackedUpRoomKey { + self.export().await.into() + } + /// Decrypt an event from a room timeline. /// /// # Arguments @@ -413,16 +460,16 @@ impl TryFrom for InboundGroupSession { let first_known_index = session.first_known_index(); Ok(InboundGroupSession { - inner: Arc::new(Mutex::new(session)), + inner: Mutex::new(session).into(), session_id: key.session_id.into(), sender_key: key.sender_key.into(), history_visibility: None.into(), first_known_index, - signing_keys: Arc::new(key.sender_claimed_keys), - room_id: Arc::new(key.room_id), - forwarding_chains: Arc::new(key.forwarding_curve25519_key_chain), + signing_keys: key.sender_claimed_keys.into(), + room_id: key.room_id.into(), + forwarding_chains: key.forwarding_curve25519_key_chain.into(), imported: true, - backed_up: false, + backed_up: AtomicBool::from(false).into(), }) } } diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs index 23328c342..400bc8160 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/mod.rs @@ -43,7 +43,7 @@ pub struct GroupSessionKey(pub String); #[zeroize(drop)] pub struct ExportedGroupSessionKey(pub String); -/// An exported version of a `InboundGroupSession` +/// An exported version of an `InboundGroupSession` /// /// This can be used to share the `InboundGroupSession` in an exported file. #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] @@ -71,6 +71,28 @@ pub struct ExportedRoomKey { pub forwarding_curve25519_key_chain: Vec, } +/// A backed up version of an `InboundGroupSession` +/// +/// This can be used to backup the `InboundGroupSession` to the server. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct BackedUpRoomKey { + /// The encryption algorithm that the session uses. + pub algorithm: EventEncryptionAlgorithm, + + /// The Curve25519 key of the device which initiated the session originally. + pub sender_key: String, + + /// The key for the session. + pub session_key: ExportedGroupSessionKey, + + /// The Ed25519 key of the device which initiated the session originally. + pub sender_claimed_keys: BTreeMap, + + /// Chain of Curve25519 keys through which this session was forwarded, via + /// m.forwarded_room_key events. + pub forwarding_curve25519_key_chain: Vec, +} + impl TryInto for ExportedRoomKey { type Error = (); @@ -104,6 +126,18 @@ impl TryInto for ExportedRoomKey { } } +impl From for BackedUpRoomKey { + fn from(k: ExportedRoomKey) -> Self { + Self { + algorithm: k.algorithm, + sender_key: k.sender_key, + session_key: k.session_key, + sender_claimed_keys: k.sender_claimed_keys, + forwarding_curve25519_key_chain: k.forwarding_curve25519_key_chain, + } + } +} + impl From for ExportedRoomKey { /// Convert the content of a forwarded room key into a exported room key. fn from(forwarded_key: ToDeviceForwardedRoomKeyEventContent) -> Self { @@ -125,6 +159,9 @@ impl From for ExportedRoomKey { #[cfg(test)] mod test { + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::wasm_bindgen_test; + use matrix_sdk_test::async_test; use std::{ sync::Arc, time::{Duration, Instant}, diff --git a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs index 8b045d13c..be7433428 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs @@ -458,6 +458,17 @@ impl PrivateCrossSigningIdentity { Ok(SignatureUploadRequest::new(signed_keys)) } + pub(crate) async fn sign(&self, message: &str) -> Result { + Ok(self + .master_key + .lock() + .await + .as_ref() + .ok_or(SignatureError::MissingSigningKey)? + .sign(message) + .await) + } + /// Create a new identity for the given Olm Account. /// /// Returns the new identity, the upload signing keys request and a @@ -770,6 +781,9 @@ mod test { let master = user_signing.sign_user(&bob_public).await.unwrap(); + let num_signatures: usize = master.signatures.iter().map(|(_, u)| u.len()).sum(); + assert_eq!(num_signatures, 1, "We're only uploading our own signature"); + bob_public.master_key = master.into(); user_signing.public_key.verify_master_key(bob_public.master_key()).unwrap(); diff --git a/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs b/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs index 258cad02f..964da3554 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs @@ -155,6 +155,10 @@ impl MasterSigning { Ok(Self { inner, public_key: pickle.public_key.into() }) } + pub async fn sign(&self, message: &str) -> String { + self.inner.sign(message).await.0 + } + pub async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { let subkey_without_signatures = json!({ "user_id": subkey.user_id.clone(), @@ -164,7 +168,7 @@ impl MasterSigning { let message = serde_json::to_string(&subkey_without_signatures) .expect("Can't serialize cross signing subkey"); - let signature = self.inner.sign(&message).await; + let signature = self.sign(&message).await; subkey .signatures @@ -176,7 +180,7 @@ impl MasterSigning { self.inner.public_key().as_str().into(), ) .to_string(), - signature.0, + signature, ); } } @@ -203,10 +207,10 @@ impl UserSigning { &self, user: &ReadOnlyUserIdentity, ) -> Result { - let signature = self.sign_user_helper(user).await?; + let signatures = self.sign_user_helper(user).await?; let mut master_key: CrossSigningKey = user.master_key().to_owned().into(); - master_key.signatures.extend(signature); + master_key.signatures = signatures; Ok(master_key) } diff --git a/crates/matrix-sdk-crypto/src/requests.rs b/crates/matrix-sdk-crypto/src/requests.rs index 39005734d..0507a391e 100644 --- a/crates/matrix-sdk-crypto/src/requests.rs +++ b/crates/matrix-sdk-crypto/src/requests.rs @@ -17,6 +17,7 @@ use std::{collections::BTreeMap, iter, sync::Arc, time::Duration}; use matrix_sdk_common::uuid::Uuid; use ruma::{ api::client::r0::{ + backup::{add_backup_keys::Response as KeysBackupResponse, RoomKeyBackup}, keys::{ claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, get_keys::Response as KeysQueryResponse, @@ -197,6 +198,8 @@ pub enum OutgoingRequests { /// A room message request, usually for sending in-room interactive /// verification events. RoomMessage(RoomMessageRequest), + /// A request that will back up a batch of room keys to the server. + KeysBackup(KeysBackupRequest), } #[cfg(test)] @@ -215,6 +218,12 @@ impl From for OutgoingRequests { } } +impl From for OutgoingRequests { + fn from(r: KeysBackupRequest) -> Self { + OutgoingRequests::KeysBackup(r) + } +} + impl From for OutgoingRequests { fn from(r: KeysClaimRequest) -> Self { Self::KeysClaim(r) @@ -278,6 +287,8 @@ pub enum IncomingResponse<'a> { SignatureUpload(&'a SignatureUploadResponse), /// A room message response, usually for interactive verifications. RoomMessage(&'a RoomMessageResponse), + /// Response for the server-side room key backup request. + KeysBackup(&'a KeysBackupResponse), } impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> { @@ -286,6 +297,12 @@ impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> { } } +impl<'a> From<&'a KeysBackupResponse> for IncomingResponse<'a> { + fn from(response: &'a KeysBackupResponse) -> Self { + IncomingResponse::KeysBackup(response) + } +} + impl<'a> From<&'a KeysQueryResponse> for IncomingResponse<'a> { fn from(response: &'a KeysQueryResponse) -> Self { IncomingResponse::KeysQuery(response) @@ -356,6 +373,16 @@ pub struct RoomMessageRequest { pub content: AnyMessageEventContent, } +/// A request that will back up a batch of room keys to the server. +#[derive(Clone, Debug)] +pub struct KeysBackupRequest { + /// The backup version that these room keys should be part of. + pub version: String, + /// The map from room id to a backed up room key that we're going to upload + /// to the server. + pub rooms: BTreeMap, +} + /// An enum over the different outgoing verification based requests. #[derive(Clone, Debug)] pub enum OutgoingVerificationRequest { diff --git a/crates/matrix-sdk-crypto/src/store/caches.rs b/crates/matrix-sdk-crypto/src/store/caches.rs index 4b9c631f0..7fac048bf 100644 --- a/crates/matrix-sdk-crypto/src/store/caches.rs +++ b/crates/matrix-sdk-crypto/src/store/caches.rs @@ -111,6 +111,11 @@ impl GroupSessionStore { .collect() } + /// Get the number of `InboundGroupSession`s we have. + pub fn count(&self) -> usize { + self.entries.iter().map(|d| d.value().values().map(|t| t.len()).sum::()).sum() + } + /// Get a inbound group session from our store. /// /// # Arguments diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index ebbd1689d..6d20a0865 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -23,7 +23,8 @@ use ruma::{DeviceId, DeviceIdBox, RoomId, UserId}; use super::{ caches::{DeviceStore, GroupSessionStore, SessionStore}, - Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, + BackupKeys, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, RoomKeyCounts, + Session, }; use crate::{ gossiping::{GossipRequest, SecretInfo}, @@ -163,6 +164,34 @@ impl CryptoStore for MemoryStore { Ok(self.inbound_group_sessions.get_all()) } + async fn inbound_group_session_counts(&self) -> Result { + let backed_up = + self.get_inbound_group_sessions().await?.into_iter().filter(|s| s.backed_up()).count(); + + Ok(RoomKeyCounts { total: self.inbound_group_sessions.count(), backed_up }) + } + + async fn inbound_group_sessions_for_backup( + &self, + limit: usize, + ) -> Result> { + Ok(self + .get_inbound_group_sessions() + .await? + .into_iter() + .filter(|s| !s.backed_up()) + .take(limit) + .collect()) + } + + async fn reset_backup_state(&self) -> Result<()> { + for session in self.get_inbound_group_sessions().await? { + session.reset_backup_state() + } + + Ok(()) + } + async fn get_outbound_group_sessions( &self, _: &RoomId, @@ -271,6 +300,10 @@ impl CryptoStore for MemoryStore { Ok(()) } + + async fn load_backup_keys(&self) -> Result { + Ok(BackupKeys::default()) + } } #[cfg(test)] diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 0faac01d8..6e45baf29 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -107,11 +107,15 @@ pub(crate) struct Store { verification_machine: VerificationMachine, } -#[derive(Clone, Debug, Default)] +#[derive(Default, Debug)] #[allow(missing_docs)] pub struct Changes { pub account: Option, pub private_identity: Option, + #[cfg(feature = "backups_v1")] + pub backup_version: Option, + #[cfg(feature = "backups_v1")] + pub recovery_key: Option, pub sessions: Vec, pub message_hashes: Vec, pub inbound_group_sessions: Vec, @@ -170,6 +174,26 @@ impl DeviceChanges { } } +/// Struct holding info about how many room keys the store has. +#[derive(Debug, Clone, Default)] +pub struct RoomKeyCounts { + /// The total number of room keys the store has. + pub total: usize, + /// The number of backed up room keys the store has. + pub backed_up: usize, +} + +/// Stored versions of the backup keys. +#[derive(Default, Debug)] +pub struct BackupKeys { + /// The recovery key, the one used to decrypt backed up room keys. + #[cfg(feature = "backups_v1")] + pub recovery_key: Option, + /// The version that we are using for backups. + #[cfg(feature = "backups_v1")] + pub backup_version: Option, +} + /// A struct containing private cross signing keys that can be backed up or /// uploaded to the secret store. #[derive(Zeroize)] @@ -431,7 +455,18 @@ impl Store { | SecretName::CrossSigningSelfSigningKey => { self.identity.lock().await.export_secret(secret_name).await } - SecretName::RecoveryKey => None, + SecretName::RecoveryKey => { + #[cfg(feature = "backups_v1")] + if let Some(key) = self.load_backup_keys().await.unwrap().recovery_key { + let exported = key.to_base64(); + Some(exported) + } else { + None + } + + #[cfg(not(feature = "backups_v1"))] + None + } name => { warn!(secret =? name, "Unknown secret was requested"); None @@ -496,7 +531,12 @@ impl Store { self.save_changes(changes).await?; } } - SecretName::RecoveryKey => (), + SecretName::RecoveryKey => { + // We don't import the recovery key here since we'll want to + // check if the public key matches to the latest version on the + // server. We instead leave the key in the event and let the + // user import it later. + } name => { warn!(secret =? name, "Tried to import an unknown secret"); } @@ -630,6 +670,22 @@ pub trait CryptoStore: AsyncTraitDeps { /// Get all the inbound group sessions we have stored. async fn get_inbound_group_sessions(&self) -> Result>; + /// Get the number inbound group sessions we have and how many of them are + /// backed up. + async fn inbound_group_session_counts(&self) -> Result; + + /// Get all the inbound group sessions we have not backed up yet. + async fn inbound_group_sessions_for_backup( + &self, + limit: usize, + ) -> Result>; + + /// Reset the backup state of all the stored inbound group sessions. + async fn reset_backup_state(&self) -> Result<()>; + + /// Get the backup keys we have stored. + async fn load_backup_keys(&self) -> Result; + /// Get the outbound group sessions we have stored that is used for the /// given room. async fn get_outbound_group_sessions( diff --git a/crates/matrix-sdk-crypto/src/store/sled.rs b/crates/matrix-sdk-crypto/src/store/sled.rs index 0e299b949..586a9e590 100644 --- a/crates/matrix-sdk-crypto/src/store/sled.rs +++ b/crates/matrix-sdk-crypto/src/store/sled.rs @@ -29,14 +29,14 @@ use ruma::{ pub use sled::Error; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, - Config, Db, Transactional, Tree, + Config, Db, IVec, Transactional, Tree, }; use tracing::trace; use uuid::Uuid; use super::{ - caches::SessionStore, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, PickleKey, - ReadOnlyAccount, Result, Session, + caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, InboundGroupSession, + PickleKey, ReadOnlyAccount, Result, RoomKeyCounts, Session, }; use crate::{ gossiping::{GossipRequest, SecretInfo}, @@ -205,9 +205,46 @@ impl SledStore { self.account_info.read().unwrap().clone() } - fn upgrade_database(db: &Db) -> Result<()> { - let version = db - .get("version")? + async fn reset_backup_state(&self) -> Result<()> { + let mut pickles: Vec<(IVec, PickledInboundGroupSession)> = self + .inbound_group_sessions + .iter() + .map(|p| { + let item = p?; + Ok(( + item.0, + serde_json::from_slice(&item.1).map_err(CryptoStoreError::Serialization)?, + )) + }) + .collect::>()?; + + for (_, pickle) in &mut pickles { + pickle.backed_up = false; + } + + let ret: Result<(), TransactionError> = + self.inbound_group_sessions.transaction(|inbound_sessions| { + for (key, pickle) in &pickles { + inbound_sessions.insert( + key, + serde_json::to_vec(&pickle).map_err(ConflictableTransactionError::Abort)?, + )?; + } + + Ok(()) + }); + + ret?; + + self.inner.flush_async().await?; + + Ok(()) + } + + fn upgrade(&self) -> Result<()> { + let version = self + .inner + .get("store_version")? .map(|v| { let (version_bytes, _) = v.split_at(std::mem::size_of::()); u8::from_be_bytes(version_bytes.try_into().unwrap_or_default()) @@ -225,17 +262,17 @@ impl SledStore { if version == 0 { // We changed the schema but migrating this isn't important since we // rotate the group sessions relatively often anyways so we just - // drop it. - db.drop_tree("outbound_group_sessions")?; + // clear the tree. + self.outbound_group_sessions.clear()?; } - db.insert("version", DATABASE_VERSION.to_be_bytes().as_ref())?; + self.inner.insert("store_version", DATABASE_VERSION.to_be_bytes().as_ref())?; + self.inner.flush()?; Ok(()) } fn open_helper(db: Db, path: Option, passphrase: Option<&str>) -> Result { - Self::upgrade_database(&db)?; let account = db.open_tree("account")?; let private_identity = db.open_tree("private_identity")?; @@ -263,7 +300,7 @@ impl SledStore { .expect("Can't create default pickle key") }; - Ok(Self { + let database = Self { account_info: RwLock::new(None).into(), path, inner: db, @@ -283,7 +320,11 @@ impl SledStore { tracked_users, olm_hashes, identities, - }) + }; + + database.upgrade()?; + + Ok(database) } fn get_or_create_pickle_key(passphrase: &str, database: &Db) -> Result { @@ -361,6 +402,9 @@ impl SledStore { None }; + #[cfg(feature = "backups_v1")] + let recovery_key_pickle = changes.recovery_key.map(|r| r.pickle(self.get_pickle_key())); + let device_changes = changes.devices; let mut session_changes = HashMap::new(); @@ -399,6 +443,8 @@ impl SledStore { let identity_changes = changes.identities; let olm_hashes = changes.message_hashes; let key_requests = changes.key_requests; + #[cfg(feature = "backups_v1")] + let backup_version = changes.backup_version; let ret: Result<(), TransactionError> = ( &self.account, @@ -441,6 +487,22 @@ impl SledStore { )?; } + #[cfg(feature = "backups_v1")] + if let Some(r) = &recovery_key_pickle { + account.insert( + "recovery_key_v1".encode(), + serde_json::to_vec(r).map_err(ConflictableTransactionError::Abort)?, + )?; + } + + #[cfg(feature = "backups_v1")] + if let Some(b) = &backup_version { + account.insert( + "backup_version_v1".encode(), + serde_json::to_vec(b).map_err(ConflictableTransactionError::Abort)?, + )?; + } + for device in device_changes.new.iter().chain(&device_changes.changed) { let key = (device.user_id().as_str(), device.device_id().as_str()).encode(); let device = serde_json::to_vec(&device) @@ -655,6 +717,57 @@ impl CryptoStore for SledStore { .collect()) } + async fn inbound_group_session_counts(&self) -> Result { + let pickles: Vec = self + .inbound_group_sessions + .iter() + .map(|p| { + let item = p?; + serde_json::from_slice(&item.1).map_err(CryptoStoreError::Serialization) + }) + .collect::>()?; + + let total = pickles.len(); + let backed_up = pickles.into_iter().filter(|p| p.backed_up).count(); + + Ok(RoomKeyCounts { total, backed_up }) + } + + async fn inbound_group_sessions_for_backup( + &self, + limit: usize, + ) -> Result> { + let pickles: Vec = self + .inbound_group_sessions + .iter() + .map(|p| { + let item = p?; + serde_json::from_slice(&item.1).map_err(CryptoStoreError::from) + }) + .filter_map(|p: Result| match p { + Ok(p) => { + if !p.backed_up { + Some( + InboundGroupSession::from_pickle(p, self.get_pickle_mode()) + .map_err(CryptoStoreError::from), + ) + } else { + None + } + } + + Err(p) => Some(Err(p)), + }) + .take(limit) + .collect::>()?; + + Ok(pickles) + } + + async fn reset_backup_state(&self) -> Result<()> { + self.reset_backup_state().await + } + async fn get_outbound_group_sessions( &self, room_id: &RoomId, @@ -670,14 +783,14 @@ impl CryptoStore for SledStore { !self.users_for_key_query_cache.is_empty() } - fn tracked_users(&self) -> HashSet { - self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect() - } - fn users_for_key_query(&self) -> HashSet { self.users_for_key_query_cache.iter().map(|u| u.clone()).collect() } + fn tracked_users(&self) -> HashSet { + self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect() + } + async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { let already_added = self.tracked_users_cache.insert(user.clone()); @@ -796,6 +909,32 @@ impl CryptoStore for SledStore { Ok(()) } + + async fn load_backup_keys(&self) -> Result { + let version = self + .account + .get("backup_version_v1".encode())? + .map(|v| serde_json::from_slice(&v)) + .transpose()?; + + #[cfg(feature = "backups_v1")] + let recovery_key = { + self.account + .get("recovery_key_v1".encode())? + .map(|p| serde_json::from_slice(&p)) + .transpose()? + .map(|p| { + crate::backups::RecoveryKey::from_pickle(p, self.get_pickle_key()) + .map_err(|_| CryptoStoreError::UnpicklingError) + }) + .transpose()? + }; + + #[cfg(not(feature = "backups_v1"))] + let recovery_key = None; + + Ok(BackupKeys { backup_version: version, recovery_key }) + } } #[cfg(test)] diff --git a/crates/matrix-sdk-crypto/src/verification/event_enums.rs b/crates/matrix-sdk-crypto/src/verification/event_enums.rs index aa2890d53..0d8e324ae 100644 --- a/crates/matrix-sdk-crypto/src/verification/event_enums.rs +++ b/crates/matrix-sdk-crypto/src/verification/event_enums.rs @@ -775,11 +775,12 @@ impl TryFrom for OutgoingContent { fn try_from(value: OutgoingRequest) -> Result { match value.request() { - crate::OutgoingRequests::KeysUpload(_) => Err("Invalid request type".to_owned()), - crate::OutgoingRequests::KeysQuery(_) => Err("Invalid request type".to_owned()), + crate::OutgoingRequests::KeysUpload(_) + | crate::OutgoingRequests::KeysQuery(_) + | crate::OutgoingRequests::KeysBackup(_) + | crate::OutgoingRequests::SignatureUpload(_) + | crate::OutgoingRequests::KeysClaim(_) => Err("Invalid request type".to_owned()), crate::OutgoingRequests::ToDeviceRequest(r) => Self::try_from(r.clone()), - crate::OutgoingRequests::SignatureUpload(_) => Err("Invalid request type".to_owned()), - crate::OutgoingRequests::KeysClaim(_) => Err("Invalid request type".to_owned()), crate::OutgoingRequests::RoomMessage(r) => Ok(Self::from(r.clone())), } } diff --git a/crates/matrix-sdk-crypto/src/verification/mod.rs b/crates/matrix-sdk-crypto/src/verification/mod.rs index 3979c5938..f006c9836 100644 --- a/crates/matrix-sdk-crypto/src/verification/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/mod.rs @@ -535,7 +535,7 @@ impl IdentitiesBeingVerified { }; if should_request_secrets { - let secret_requests = self.request_missing_secrets().await; + let secret_requests = self.request_missing_secrets().await?; changes.key_requests = secret_requests; } @@ -547,9 +547,16 @@ impl IdentitiesBeingVerified { .unwrap_or(VerificationResult::Ok)) } - async fn request_missing_secrets(&self) -> Vec { - let secrets = self.private_identity.get_missing_secrets().await; - GossipMachine::request_missing_secrets(self.user_id(), secrets) + async fn request_missing_secrets(&self) -> Result, CryptoStoreError> { + #[allow(unused_mut)] + let mut secrets = self.private_identity.get_missing_secrets().await; + + #[cfg(feature = "backups_v1")] + if self.store.inner.load_backup_keys().await?.recovery_key.is_none() { + secrets.push(ruma::events::secret::request::SecretName::RecoveryKey); + } + + Ok(GossipMachine::request_missing_secrets(self.user_id(), secrets)) } async fn mark_identity_as_verified( @@ -667,9 +674,6 @@ impl IdentitiesBeingVerified { #[cfg(test)] pub(crate) mod test { - #[cfg(target_arch = "wasm32")] - use wasm_bindgen_test::wasm_bindgen_test; - use matrix_sdk_test::async_test; use std::convert::TryInto; use ruma::{ diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index fa63b1898..b818c724b 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -110,7 +110,7 @@ features = ["wasm-bindgen"] [dev-dependencies] anyhow = "1.0" dirs = "3.0.2" -futures = { version = "0.3.15", default-features = false } +futures = { version = "0.3.15", default-features = false, features = ["executor"] } lazy_static = "1.4.0" matches = "0.1.8" matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" } diff --git a/crates/matrix-sdk/src/client.rs b/crates/matrix-sdk/src/client.rs index 2842f8679..39e3869c8 100644 --- a/crates/matrix-sdk/src/client.rs +++ b/crates/matrix-sdk/src/client.rs @@ -26,15 +26,14 @@ use std::{ use anymap2::any::CloneAnySendSync; use dashmap::DashMap; use futures_core::stream::Stream; -use futures_timer::Delay as sleep; use matrix_sdk_base::{ - deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse}, + deserialized_responses::SyncResponse, media::{MediaEventContent, MediaFormat, MediaRequest, MediaThumbnailSize, MediaType}, BaseClient, Session, Store, }; use matrix_sdk_common::{ instant::{Duration, Instant}, - locks::{Mutex, RwLock}, + locks::{Mutex, RwLock, RwLockReadGuard}, }; use mime::{self, Mime}; use ruma::{ @@ -109,27 +108,31 @@ pub enum LoopCtrl { /// All of the state is held in an `Arc` so the `Client` can be cloned freely. #[derive(Clone)] pub struct Client { + pub(crate) inner: Arc, +} + +pub(crate) struct ClientInner { /// The URL of the homeserver to connect to. homeserver: Arc>, /// The underlying HTTP client. http_client: HttpClient, /// User session data. - pub(crate) base_client: BaseClient, + base_client: BaseClient, /// Locks making sure we only have one group session sharing request in /// flight per room. #[cfg(feature = "encryption")] - pub(crate) group_session_locks: Arc>>>, + pub(crate) group_session_locks: DashMap>>, #[cfg(feature = "encryption")] /// Lock making sure we're only doing one key claim request at a time. - pub(crate) key_claim_lock: Arc>, - pub(crate) members_request_locks: Arc>>>, - pub(crate) typing_notice_times: Arc>, + pub(crate) key_claim_lock: Mutex<()>, + pub(crate) members_request_locks: DashMap>>, + pub(crate) typing_notice_times: DashMap, /// Event handlers. See `register_event_handler`. - pub(crate) event_handlers: Arc>, + event_handlers: RwLock, /// Custom event handler context. See `register_event_handler_context`. - pub(crate) event_handler_data: Arc>, + event_handler_data: StdRwLock, /// Notification handlers. See `register_notification_handler`. - notification_handlers: Arc>>, + notification_handlers: RwLock>, /// Whether the client should operate in application service style mode. /// This is low-level functionality. For an high-level API check the /// `matrix_sdk_appservice` crate. @@ -142,7 +145,7 @@ pub struct Client { /// synchronization, e.g. if we send out a request to create a room, we can /// wait for the sync to get the data to fetch a room object from the state /// store. - pub(crate) sync_beat: Arc, + pub(crate) sync_beat: event_listener::Event, } #[cfg(not(tarpaulin_include))] @@ -186,7 +189,7 @@ impl Client { let http_client = HttpClient::new(client, homeserver.clone(), session, config.request_config); - Ok(Self { + let inner = Arc::new(ClientInner { homeserver, http_client, base_client, @@ -201,8 +204,10 @@ impl Client { notification_handlers: Default::default(), appservice_mode: config.appservice_mode, use_discovery_response: config.use_discovery_response, - sync_beat: event_listener::Event::new().into(), - }) + sync_beat: event_listener::Event::new(), + }); + + Ok(Self { inner }) } /// Create a new [`Client`] using homeserver auto discovery. @@ -226,12 +231,12 @@ impl Client { /// // First let's try to construct an user id, presumably from user input. /// let alice = UserId::try_from("@alice:example.org")?; /// - /// // Now let's try to discover the homeserver and create client object. + /// // Now let's try to discover the homeserver and create a client object. /// let client = Client::new_from_user_id(&alice).await?; /// /// // Finally let's try to login. /// client.login(alice, "password", None, None).await?; - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` /// /// [spec]: https://spec.matrix.org/unstable/client-server-api/#well-known-uri @@ -268,6 +273,24 @@ impl Client { Ok(client) } + pub(crate) fn base_client(&self) -> &BaseClient { + &self.inner.base_client + } + + #[cfg(feature = "encryption")] + pub(crate) async fn olm_machine(&self) -> Option { + self.base_client().olm_machine().await + } + + #[cfg(feature = "encryption")] + pub(crate) async fn mark_request_as_sent( + &self, + request_id: &matrix_sdk_base::uuid::Uuid, + response: impl Into>, + ) -> Result<(), matrix_sdk_base::Error> { + self.base_client().mark_request_as_sent(request_id, response).await + } + fn homeserver_from_user_id(user_id: &UserId) -> Result { let homeserver = format!("https://{}", user_id.server_name()); #[allow(unused_mut)] @@ -290,7 +313,7 @@ impl Client { /// /// * `homeserver_url` - The new URL to use. pub async fn set_homeserver(&self, homeserver_url: Url) { - let mut homeserver = self.homeserver.write().await; + let mut homeserver = self.inner.homeserver.write().await; *homeserver = homeserver_url; } @@ -324,23 +347,23 @@ impl Client { /// Is the client logged in. pub async fn logged_in(&self) -> bool { - self.base_client.logged_in().await + self.inner.base_client.logged_in().await } /// The Homeserver of the client. pub async fn homeserver(&self) -> Url { - self.homeserver.read().await.clone() + self.inner.homeserver.read().await.clone() } /// Get the user id of the current owner of the client. pub async fn user_id(&self) -> Option { - let session = self.base_client.session().read().await; + let session = self.inner.base_client.session().read().await; session.as_ref().cloned().map(|s| s.user_id) } /// Get the device id that identifies the current session. pub async fn device_id(&self) -> Option { - let session = self.base_client.session().read().await; + let session = self.inner.base_client.session().read().await; session.as_ref().map(|s| s.device_id.clone()) } @@ -351,7 +374,7 @@ impl Client { /// Can be used with [`Client::restore_login`] to restore a previously /// logged in session. pub async fn session(&self) -> Option { - self.base_client.session().read().await.clone() + self.inner.base_client.session().read().await.clone() } /// Fetches the display name of the owner of the client. @@ -469,7 +492,7 @@ impl Client { /// Get a reference to the store. pub fn store(&self) -> &Store { - self.base_client.store() + self.inner.base_client.store() } /// Sets the mxc avatar url of the client's owner. The avatar gets unset if @@ -611,32 +634,41 @@ impl Client { ::Output: EventHandlerResult, { let event_type = H::ID.1; - self.event_handlers.write().await.entry(H::ID).or_default().push(Box::new(move |data| { - let maybe_fut = serde_json::from_str(data.raw.get()) - .map(|ev| handler.clone().handle_event(ev, data)); + self.inner.event_handlers.write().await.entry(H::ID).or_default().push(Box::new( + move |data| { + let maybe_fut = serde_json::from_str(data.raw.get()) + .map(|ev| handler.clone().handle_event(ev, data)); - Box::pin(async move { - match maybe_fut { - Ok(Some(fut)) => { - fut.await.print_error(event_type); - } - Ok(None) => { - error!("Event handler for {} has an invalid context argument", event_type); - } - Err(e) => { - warn!( - "Failed to deserialize `{}` event, skipping event handler.\n\ + Box::pin(async move { + match maybe_fut { + Ok(Some(fut)) => { + fut.await.print_error(event_type); + } + Ok(None) => { + error!( + "Event handler for {} has an invalid context argument", + event_type + ); + } + Err(e) => { + warn!( + "Failed to deserialize `{}` event, skipping event handler.\n\ Deserialization error: {}", - event_type, e, - ); + event_type, e, + ); + } } - } - }) - })); + }) + }, + )); self } + pub(crate) async fn event_handlers(&self) -> RwLockReadGuard<'_, EventHandlerMap> { + self.inner.event_handlers.read().await + } + /// Add an arbitrary value for use as event handler context. /// /// The value can be obtained in an event handler by adding an argument of @@ -678,10 +710,18 @@ impl Client { where T: Clone + Send + Sync + 'static, { - self.event_handler_data.write().unwrap().insert(ctx); + self.inner.event_handler_data.write().unwrap().insert(ctx); self } + pub(crate) fn event_handler_context(&self) -> Option + where + T: Clone + Send + Sync + 'static, + { + let map = self.inner.event_handler_data.read().unwrap(); + map.get::().cloned() + } + /// Register a handler for a notification. /// /// Similar to [`Client::register_event_handler`], but only allows functions @@ -692,13 +732,19 @@ impl Client { H: Fn(Notification, room::Room, Client) -> Fut + Send + Sync + 'static, Fut: Future + Send + 'static, { - self.notification_handlers.write().await.push(Box::new( + self.inner.notification_handlers.write().await.push(Box::new( move |notification, room, client| Box::pin((handler)(notification, room, client)), )); self } + pub(crate) async fn notification_handlers( + &self, + ) -> RwLockReadGuard<'_, Vec> { + self.inner.notification_handlers.read().await + } + /// Get all the rooms the client knows about. /// /// This will return the list of joined, invited, and left rooms. @@ -849,7 +895,7 @@ impl Client { /// "Logged in as {}, got device_id {} and access_token {}", /// user, response.device_id, response.access_token /// ); - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` /// /// [`restore_login`]: #method.restore_login @@ -1166,7 +1212,7 @@ impl Client { /// /// * `response` - A successful login response. async fn receive_login_response(&self, response: &login::Response) -> Result<()> { - if self.use_discovery_response { + if self.inner.use_discovery_response { if let Some(well_known) = &response.well_known { if let Ok(homeserver) = Url::parse(&well_known.homeserver.base_url) { self.set_homeserver(homeserver).await; @@ -1174,7 +1220,7 @@ impl Client { } } - self.base_client.receive_login_response(response).await?; + self.inner.base_client.receive_login_response(response).await?; Ok(()) } @@ -1192,9 +1238,52 @@ impl Client { /// * `session` - A session that the user already has from a /// previous login call. /// + /// # Examples + /// + /// ```no_run + /// use matrix_sdk::{Client, Session, ruma::{DeviceIdBox, user_id}}; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # block_on(async { + /// + /// let homeserver = Url::parse("http://example.com")?; + /// let client = Client::new(homeserver).await?; + /// + /// let session = Session { + /// access_token: "My-Token".to_owned(), + /// user_id: user_id!("@example:localhost"), + /// device_id: DeviceIdBox::from("MYDEVICEID"), + /// }; + /// + /// client.restore_login(session).await?; + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + /// + /// The `Session` object can also be created from the response the + /// [`Client::login()`] method returns: + /// + /// ```no_run + /// use matrix_sdk::{Client, Session, ruma::{DeviceIdBox, user_id}}; + /// # use url::Url; + /// # use futures::executor::block_on; + /// # block_on(async { + /// + /// let homeserver = Url::parse("http://example.com")?; + /// let client = Client::new(homeserver).await?; + /// + /// let session: Session = client + /// .login("example", "my-password", None, None) + /// .await? + /// .into(); + /// + /// // Persist the `Session` so it can later be used to restore the login. + /// client.restore_login(session).await?; + /// # anyhow::Result::<()>::Ok(()) }); + /// ``` + /// /// [`login`]: #method.login pub async fn restore_login(&self, session: Session) -> Result<()> { - Ok(self.base_client.restore_login(session).await?) + Ok(self.inner.base_client.restore_login(session).await?) } /// Register a user to the server. @@ -1241,8 +1330,8 @@ impl Client { let homeserver = self.homeserver().await; info!("Registering to {}", homeserver); - let config = if self.appservice_mode { - Some(self.http_client.request_config.force_auth()) + let config = if self.inner.appservice_mode { + Some(self.inner.http_client.request_config.force_auth()) } else { None }; @@ -1307,14 +1396,14 @@ impl Client { filter_name: &str, definition: FilterDefinition<'_>, ) -> Result { - if let Some(filter) = self.base_client.get_filter(filter_name).await? { + if let Some(filter) = self.inner.base_client.get_filter(filter_name).await? { Ok(filter) } else { let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; let request = FilterUploadRequest::new(&user_id, definition); let response = self.send(request, None).await?; - self.base_client.receive_filter_upload(filter_name, &response).await?; + self.inner.base_client.receive_filter_upload(filter_name, &response).await?; Ok(response.filter_id) } @@ -1468,7 +1557,7 @@ impl Client { /// for room in response.chunk { /// println!("Found room {:?}", room); /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn public_rooms_filtered( &self, @@ -1526,8 +1615,8 @@ impl Client { content_type: Some(content_type.essence_str()), }); - let request_config = self.http_client.request_config.timeout(timeout); - Ok(self.http_client.upload(request, Some(request_config)).await?) + let request_config = self.inner.http_client.request_config.timeout(timeout); + Ok(self.inner.http_client.upload(request, Some(request_config)).await?) } /// Send an arbitrary request to the server, without updating client state. @@ -1568,7 +1657,7 @@ impl Client { /// /// // Check the corresponding Response struct to find out what types are /// // returned - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn send( &self, @@ -1579,7 +1668,7 @@ impl Client { Request: OutgoingRequest + Debug, HttpError: From>, { - Ok(self.http_client.send(request, config).await?) + Ok(self.inner.http_client.send(request, config).await?) } /// Get information of all our own devices. @@ -1603,7 +1692,7 @@ impl Client { /// device.display_name.as_deref().unwrap_or("") /// ); /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn devices(&self) -> HttpResult { let request = get_devices::Request::new(); @@ -1656,7 +1745,7 @@ impl Client { /// .await?; /// } /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); pub async fn delete_devices( &self, devices: &[DeviceIdBox], @@ -1668,135 +1757,6 @@ impl Client { self.send(request, None).await } - pub(crate) async fn process_sync( - &self, - response: sync_events::Response, - ) -> Result { - let response = self.base_client.receive_sync_response(response).await?; - let SyncResponse { - next_batch: _, - rooms, - presence, - account_data, - to_device: _, - device_lists: _, - device_one_time_keys_count: _, - ambiguity_changes: _, - notifications, - } = &response; - - self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?; - self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?; - - for (room_id, room_info) in &rooms.join { - let room = self.get_room(room_id); - if room.is_none() { - error!("Can't call event handler, room {} not found", room_id); - continue; - } - - let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } = - room_info; - - self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?; - self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events) - .await?; - self.handle_sync_state_events(&room, &state.events).await?; - self.handle_sync_timeline_events(&room, &timeline.events).await?; - } - - for (room_id, room_info) in &rooms.leave { - let room = self.get_room(room_id); - if room.is_none() { - error!("Can't call event handler, room {} not found", room_id); - continue; - } - - let LeftRoom { timeline, state, account_data } = room_info; - - self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events) - .await?; - self.handle_sync_state_events(&room, &state.events).await?; - self.handle_sync_timeline_events(&room, &timeline.events).await?; - } - - for (room_id, room_info) in &rooms.invite { - let room = self.get_room(room_id); - if room.is_none() { - error!("Can't call event handler, room {} not found", room_id); - continue; - } - - // FIXME: Destructure room_info - self.handle_sync_events( - EventKind::StrippedState, - &room, - &room_info.invite_state.events, - ) - .await?; - } - - // Construct notification event handler futures - let mut futures = Vec::new(); - for handler in &*self.notification_handlers.read().await { - for (room_id, room_notifications) in notifications { - let room = match self.get_room(room_id) { - Some(room) => room, - None => { - warn!("Can't call notification handler, room {} not found", room_id); - continue; - } - }; - - futures.extend(room_notifications.iter().map(|notification| { - (handler)(notification.clone(), room.clone(), self.clone()) - })); - } - } - - // Run the notification handler futures with the - // `self.notification_handlers` lock no longer being held, in order. - for fut in futures { - fut.await; - } - - Ok(response) - } - - async fn sync_loop_helper( - &self, - sync_settings: &mut crate::config::SyncSettings<'_>, - ) -> Result { - let response = self.sync_once(sync_settings.clone()).await; - - match response { - Ok(r) => { - sync_settings.token = Some(r.next_batch.clone()); - Ok(r) - } - Err(e) => { - error!("Received an invalid response: {}", e); - sleep::new(Duration::from_secs(1)).await; - Err(e) - } - } - } - - async fn delay_sync(last_sync_time: &mut Option) { - let now = Instant::now(); - - // If the last sync happened less than a second ago, sleep for a - // while to not hammer out requests if the server doesn't respect - // the sync timeout. - if let Some(t) = last_sync_time { - if now - *t <= Duration::from_secs(1) { - sleep::new(Duration::from_secs(1)).await; - } - } - - *last_sync_time = Some(now); - } - /// Synchronize the client's state with the latest state on the server. /// /// ## Syncing Events @@ -1874,7 +1834,7 @@ impl Client { /// // Now keep on syncing forever. `sync()` will use the stored sync token /// // from our `sync_once()` call automatically. /// client.sync(SyncSettings::default()).await; - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` /// /// [`sync`]: #method.sync @@ -1901,9 +1861,9 @@ impl Client { timeout: sync_settings.timeout, }); - let request_config = self.http_client.request_config.timeout( + let request_config = self.inner.http_client.request_config.timeout( sync_settings.timeout.unwrap_or_else(|| Duration::from_secs(0)) - + self.http_client.request_config.timeout, + + self.inner.http_client.request_config.timeout, ); let response = self.send(request, Some(request_config)).await?; @@ -1914,7 +1874,7 @@ impl Client { error!(error =? e, "Error while sending outgoing E2EE requests"); }; - self.sync_beat.notify(usize::MAX); + self.inner.sync_beat.notify(usize::MAX); Ok(response) } @@ -1931,7 +1891,7 @@ impl Client { /// method to react to individual events. If you instead wish to handle /// events in a bulk manner the [`Client::sync_with_callback`] and /// [`Client::sync_stream`] methods can be used instead. Those two methods - /// repeadetly return the whole sync response. + /// repeatedly return the whole sync response. /// /// # Arguments /// @@ -1965,7 +1925,7 @@ impl Client { /// // Now keep on syncing forever. `sync()` will use the latest sync token /// // automatically. /// client.sync(SyncSettings::default()).await; - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` /// /// [argument docs]: #method.sync_once @@ -2094,7 +2054,7 @@ impl Client { /// } /// } /// - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` #[instrument] pub async fn sync_stream<'a>( @@ -2119,7 +2079,7 @@ impl Client { /// Get the current, if any, sync token of the client. /// This will be None if the client didn't sync at least once. pub async fn sync_token(&self) -> Option { - self.base_client.sync_token().await + self.inner.base_client.sync_token().await } /// Get a media file's content. @@ -2138,7 +2098,7 @@ impl Client { use_cache: bool, ) -> Result> { let content = if use_cache { - self.base_client.store().get_media_content(request).await? + self.inner.base_client.store().get_media_content(request).await? } else { None }; @@ -2182,7 +2142,7 @@ impl Client { }; if use_cache { - self.base_client.store().add_media_content(request, content.clone()).await?; + self.inner.base_client.store().add_media_content(request, content.clone()).await?; } Ok(content) @@ -2195,7 +2155,7 @@ impl Client { /// /// * `request` - The `MediaRequest` of the content. pub async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { - Ok(self.base_client.store().remove_media_content(request).await?) + Ok(self.inner.base_client.store().remove_media_content(request).await?) } /// Delete all the media content corresponding to the given @@ -2205,7 +2165,7 @@ impl Client { /// /// * `uri` - The `MxcUri` of the files. pub async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { - Ok(self.base_client.store().remove_media_content_for_uri(uri).await?) + Ok(self.inner.base_client.store().remove_media_content_for_uri(uri).await?) } /// Get the file of the given media event content. @@ -2724,7 +2684,7 @@ pub(crate) mod test { .add_state_event(EventsJson::PowerLevels) .build_sync_response(); - client.base_client.receive_sync_response(response).await.unwrap(); + client.inner.base_client.receive_sync_response(response).await.unwrap(); let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost"); assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap()); diff --git a/crates/matrix-sdk/src/config/client.rs b/crates/matrix-sdk/src/config/client.rs index 07772887b..e3c315206 100644 --- a/crates/matrix-sdk/src/config/client.rs +++ b/crates/matrix-sdk/src/config/client.rs @@ -90,7 +90,7 @@ impl ClientConfig { /// let client_config = ClientConfig::new() /// .proxy("http://localhost:8080")?; /// - /// # matrix_sdk::Result::Ok(()) + /// # Result::<_, matrix_sdk::Error>::Ok(()) /// ``` #[cfg(not(target_arch = "wasm32"))] pub fn proxy(mut self, proxy: &str) -> Result { @@ -105,7 +105,7 @@ impl ClientConfig { } /// Set a custom HTTP user agent for the client. - pub fn user_agent(mut self, user_agent: &str) -> std::result::Result { + pub fn user_agent(mut self, user_agent: &str) -> Result { self.user_agent = Some(HeaderValue::from_str(user_agent)?); Ok(self) } diff --git a/crates/matrix-sdk/src/encryption/identities/devices.rs b/crates/matrix-sdk/src/encryption/identities/devices.rs index fab4e11ba..a45a34b93 100644 --- a/crates/matrix-sdk/src/encryption/identities/devices.rs +++ b/crates/matrix-sdk/src/encryption/identities/devices.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ops::Deref, result::Result as StdResult}; +use std::ops::Deref; use matrix_sdk_base::crypto::{ store::CryptoStoreError, Device as BaseDevice, LocalTrust, ReadOnlyDevice, @@ -250,7 +250,7 @@ impl Device { /// } /// # anyhow::Result::<()>::Ok(()) }); /// ``` - pub async fn verify(&self) -> std::result::Result<(), ManualVerifyError> { + pub async fn verify(&self) -> Result<(), ManualVerifyError> { let request = self.inner.verify().await?; self.client.send(request, None).await?; @@ -391,10 +391,7 @@ impl Device { /// # Arguments /// /// * `trust_state` - The new trust state that should be set for the device. - pub async fn set_local_trust( - &self, - trust_state: LocalTrust, - ) -> StdResult<(), CryptoStoreError> { + pub async fn set_local_trust(&self, trust_state: LocalTrust) -> Result<(), CryptoStoreError> { self.inner.set_local_trust(trust_state).await } } diff --git a/crates/matrix-sdk/src/encryption/identities/users.rs b/crates/matrix-sdk/src/encryption/identities/users.rs index 9ae2042e1..4e345b22f 100644 --- a/crates/matrix-sdk/src/encryption/identities/users.rs +++ b/crates/matrix-sdk/src/encryption/identities/users.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{result::Result, sync::Arc}; +use std::sync::Arc; use matrix_sdk_base::{ crypto::{ diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index ad7a38eef..4cbcb0fd3 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -266,6 +266,7 @@ use matrix_sdk_base::{ use matrix_sdk_common::{instant::Duration, uuid::Uuid}; use ruma::{ api::client::r0::{ + backup::add_backup_keys::Response as KeysBackupResponse, keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest}, message::send_message_event, to_device::send_event_to_device::{ @@ -294,7 +295,7 @@ impl Client { /// called the fingerprint of the device. #[cfg(feature = "encryption")] pub async fn ed25519_key(&self) -> Option { - self.base_client.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned()) + self.olm_machine().await.map(|o| o.identity_keys().ed25519().to_owned()) } /// Get the status of the private cross signing keys. @@ -303,7 +304,7 @@ impl Client { /// stored locally. #[cfg(feature = "encryption")] pub async fn cross_signing_status(&self) -> Option { - if let Some(machine) = self.base_client.olm_machine().await { + if let Some(machine) = self.olm_machine().await { Some(machine.cross_signing_status().await) } else { None @@ -316,13 +317,13 @@ impl Client { /// capable devices up to date. #[cfg(feature = "encryption")] pub async fn tracked_users(&self) -> HashSet { - self.base_client.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default() + self.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default() } /// Get a verification object with the given flow id. #[cfg(feature = "encryption")] pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { - let olm = self.base_client.olm_machine().await?; + let olm = self.olm_machine().await?; olm.get_verification(user_id, flow_id).map(|v| match v { matrix_sdk_base::crypto::Verification::SasV1(s) => { SasVerification { inner: s, client: self.clone() }.into() @@ -342,7 +343,7 @@ impl Client { user_id: &UserId, flow_id: impl AsRef, ) -> Option { - let olm = self.base_client.olm_machine().await?; + let olm = self.olm_machine().await?; olm.get_verification_request(user_id, flow_id) .map(|r| VerificationRequest { inner: r, client: self.clone() }) @@ -387,7 +388,7 @@ impl Client { user_id: &UserId, device_id: &DeviceId, ) -> StdResult, CryptoStoreError> { - let device = self.base_client.get_device(user_id, device_id).await?; + let device = self.base_client().get_device(user_id, device_id).await?; Ok(device.map(|d| Device { inner: d, client: self.clone() })) } @@ -424,7 +425,7 @@ impl Client { &self, user_id: &UserId, ) -> StdResult { - let devices = self.base_client.get_user_devices(user_id).await?; + let devices = self.base_client().get_user_devices(user_id).await?; Ok(UserDevices { inner: devices, client: self.clone() }) } @@ -467,7 +468,7 @@ impl Client { ) -> StdResult, CryptoStoreError> { use crate::encryption::identities::UserIdentity; - if let Some(olm) = self.base_client.olm_machine().await { + if let Some(olm) = self.olm_machine().await { let identity = olm.get_identity(user_id).await?; Ok(identity.map(|i| match i { @@ -525,7 +526,7 @@ impl Client { /// # anyhow::Result::<()>::Ok(()) }); #[cfg(feature = "encryption")] pub async fn bootstrap_cross_signing(&self, auth_data: Option>) -> Result<()> { - let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?; + let olm = self.olm_machine().await.ok_or(Error::AuthenticationRequired)?; let (request, signature_request) = olm.bootstrap_cross_signing(false).await?; @@ -600,7 +601,7 @@ impl Client { passphrase: &str, predicate: impl FnMut(&matrix_sdk_base::crypto::olm::InboundGroupSession) -> bool, ) -> Result<()> { - let olm = self.base_client.olm_machine().await.ok_or(Error::AuthenticationRequired)?; + let olm = self.olm_machine().await.ok_or(Error::AuthenticationRequired)?; let keys = olm.export_keys(predicate).await?; let passphrase = zeroize::Zeroizing::new(passphrase.to_owned()); @@ -660,7 +661,7 @@ impl Client { path: PathBuf, passphrase: &str, ) -> StdResult { - let olm = self.base_client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?; + let olm = self.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?; let passphrase = zeroize::Zeroizing::new(passphrase.to_owned()); let decrypt = move || { @@ -683,7 +684,7 @@ impl Client { ) -> serde_json::Result { use ruma::serde::JsonObject; - if let Some(machine) = self.base_client.olm_machine().await { + if let Some(machine) = self.olm_machine().await { if let AnyRoomEvent::Message(event) = event { if let AnyMessageEvent::RoomEncrypted(_) = event { let room_id = event.room_id(); @@ -726,7 +727,7 @@ impl Client { let request = assign!(get_keys::Request::new(), { device_keys }); let response = self.send(request, None).await?; - self.base_client.mark_request_as_sent(request_id, &response).await?; + self.mark_request_as_sent(request_id, &response).await?; Ok(response) } @@ -843,7 +844,7 @@ impl Client { if let Some(room) = self.get_joined_room(&response.room_id) { Ok(Some(room)) } else { - self.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME); + self.inner.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME); Ok(self.get_joined_room(&response.room_id)) } } @@ -859,11 +860,11 @@ impl Client { &self, users: impl Iterator, ) -> Result<()> { - let _lock = self.key_claim_lock.lock().await; + let _lock = self.inner.key_claim_lock.lock().await; - if let Some((request_id, request)) = self.base_client.get_missing_sessions(users).await? { + if let Some((request_id, request)) = self.base_client().get_missing_sessions(users).await? { let response = self.send(request, None).await?; - self.base_client.mark_request_as_sent(&request_id, &response).await?; + self.mark_request_as_sent(&request_id, &response).await?; } Ok(()) @@ -892,7 +893,7 @@ impl Client { ); let response = self.send(request.clone(), None).await?; - self.base_client.mark_request_as_sent(request_id, &response).await?; + self.mark_request_as_sent(request_id, &response).await?; Ok(response) } @@ -970,25 +971,41 @@ impl Client { } OutgoingRequests::ToDeviceRequest(request) => { let response = self.send_to_device(request).await?; - self.base_client.mark_request_as_sent(r.request_id(), &response).await?; + self.mark_request_as_sent(r.request_id(), &response).await?; } OutgoingRequests::SignatureUpload(request) => { let response = self.send(request.clone(), None).await?; - self.base_client.mark_request_as_sent(r.request_id(), &response).await?; + self.mark_request_as_sent(r.request_id(), &response).await?; } OutgoingRequests::RoomMessage(request) => { let response = self.room_send_helper(request).await?; - self.base_client.mark_request_as_sent(r.request_id(), &response).await?; + self.mark_request_as_sent(r.request_id(), &response).await?; } OutgoingRequests::KeysClaim(request) => { let response = self.send(request.clone(), None).await?; - self.base_client.mark_request_as_sent(r.request_id(), &response).await?; + self.mark_request_as_sent(r.request_id(), &response).await?; + } + OutgoingRequests::KeysBackup(request) => { + let response = self.send_backup_request(request).await?; + self.mark_request_as_sent(r.request_id(), &response).await?; } } Ok(()) } + async fn send_backup_request( + &self, + request: &matrix_sdk_base::crypto::KeysBackupRequest, + ) -> Result { + let request = ruma::api::client::r0::backup::add_backup_keys::Request::new( + &request.version, + request.rooms.to_owned(), + ); + + Ok(self.send(request, None).await?) + } + pub(crate) async fn send_outgoing_requests(&self) -> Result<()> { const MAX_CONCURRENT_REQUESTS: usize = 20; @@ -998,7 +1015,7 @@ impl Client { warn!("Error while claiming one-time keys {:?}", e); } - let outgoing_requests = stream::iter(self.base_client.outgoing_requests().await?) + let outgoing_requests = stream::iter(self.base_client().outgoing_requests().await?) .map(|r| self.send_outgoing_request(r)); let requests = outgoing_requests.buffer_unordered(MAX_CONCURRENT_REQUESTS); diff --git a/crates/matrix-sdk/src/encryption/verification/qrcode.rs b/crates/matrix-sdk/src/encryption/verification/qrcode.rs index 78b80e807..87322ff3d 100644 --- a/crates/matrix-sdk/src/encryption/verification/qrcode.rs +++ b/crates/matrix-sdk/src/encryption/verification/qrcode.rs @@ -71,7 +71,7 @@ impl QrVerification { /// /// The [`to_bytes()`](#method.to_bytes) method can be used to instead /// output the raw bytes that should be encoded as a QR code. - pub fn to_qr_code(&self) -> std::result::Result { + pub fn to_qr_code(&self) -> Result { self.inner.to_qr_code() } @@ -80,7 +80,7 @@ impl QrVerification { /// /// The [`to_qr_code()`](#method.to_qr_code) method can be used to instead /// output a `QrCode` object that can be rendered. - pub fn to_bytes(&self) -> std::result::Result, EncodingError> { + pub fn to_bytes(&self) -> Result, EncodingError> { self.inner.to_bytes() } diff --git a/crates/matrix-sdk/src/error.rs b/crates/matrix-sdk/src/error.rs index 33ab68f0c..cd02986a3 100644 --- a/crates/matrix-sdk/src/error.rs +++ b/crates/matrix-sdk/src/error.rs @@ -40,7 +40,7 @@ use thiserror::Error; use url::ParseError as UrlParseError; /// Result type of the matrix-sdk. -pub type Result = std::result::Result; +pub type Result = std::result::Result; /// Result type of a pure HTTP request. pub type HttpResult = std::result::Result; diff --git a/crates/matrix-sdk/src/event_handler.rs b/crates/matrix-sdk/src/event_handler.rs index ee48a053a..97993eed7 100644 --- a/crates/matrix-sdk/src/event_handler.rs +++ b/crates/matrix-sdk/src/event_handler.rs @@ -186,8 +186,7 @@ pub struct Ctx(pub T); impl EventHandlerContext for Ctx { fn from_data(data: &EventHandlerData<'_>) -> Option { - let anymap = data.client.event_handler_data.read().unwrap(); - Some(Ctx(anymap.get::()?.clone())) + data.client.event_handler_context::().map(Ctx) } } @@ -330,8 +329,7 @@ impl Client { // Construct event handler futures let futures: Vec<_> = self - .event_handlers - .read() + .event_handlers() .await .get(&event_handler_id) .into_iter() diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index a9863acd8..2adca293f 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -53,8 +53,8 @@ pub mod event_handler; mod http_client; /// High-level room API pub mod room; -/// High-level room API mod room_member; +mod sync; #[cfg(feature = "encryption")] pub mod encryption; diff --git a/crates/matrix-sdk/src/room/common.rs b/crates/matrix-sdk/src/room/common.rs index 84559b67b..9ad325634 100644 --- a/crates/matrix-sdk/src/room/common.rs +++ b/crates/matrix-sdk/src/room/common.rs @@ -168,14 +168,17 @@ impl Common { pub(crate) async fn request_members(&self) -> Result> { if let Some(mutex) = - self.client.members_request_locks.get(self.inner.room_id()).map(|m| m.clone()) + self.client.inner.members_request_locks.get(self.inner.room_id()).map(|m| m.clone()) { mutex.lock().await; Ok(None) } else { let mutex = Arc::new(Mutex::new(())); - self.client.members_request_locks.insert(self.inner.room_id().clone(), mutex.clone()); + self.client + .inner + .members_request_locks + .insert(self.inner.room_id().clone(), mutex.clone()); let _guard = mutex.lock().await; @@ -183,9 +186,9 @@ impl Common { let response = self.client.send(request, None).await?; let response = - self.client.base_client.receive_members(self.inner.room_id(), &response).await?; + self.client.base_client().receive_members(self.inner.room_id(), &response).await?; - self.client.members_request_locks.remove(self.inner.room_id()); + self.client.inner.members_request_locks.remove(self.inner.room_id()); Ok(Some(response)) } diff --git a/crates/matrix-sdk/src/room/joined.rs b/crates/matrix-sdk/src/room/joined.rs index 36af5b445..3425fe28e 100644 --- a/crates/matrix-sdk/src/room/joined.rs +++ b/crates/matrix-sdk/src/room/joined.rs @@ -170,37 +170,39 @@ impl Joined { /// if let Some(room) = client.get_joined_room(&room_id) { /// room.typing_notice(true).await? /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn typing_notice(&self, typing: bool) -> Result<()> { // Only send a request to the homeserver if the old timeout has elapsed // or the typing notice changed state within the // TYPING_NOTICE_TIMEOUT - let send = - if let Some(typing_time) = self.client.typing_notice_times.get(self.inner.room_id()) { - if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT { - // We always reactivate the typing notice if typing is true or - // we may need to deactivate it if it's - // currently active if typing is false - typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT - } else { - // Only send a request when we need to deactivate typing - !typing - } + let send = if let Some(typing_time) = + self.client.inner.typing_notice_times.get(self.inner.room_id()) + { + if typing_time.elapsed() > TYPING_NOTICE_RESEND_TIMEOUT { + // We always reactivate the typing notice if typing is true or + // we may need to deactivate it if it's + // currently active if typing is false + typing || typing_time.elapsed() <= TYPING_NOTICE_TIMEOUT } else { - // Typing notice is currently deactivated, therefore, send a request - // only when it's about to be activated - typing - }; + // Only send a request when we need to deactivate typing + !typing + } + } else { + // Typing notice is currently deactivated, therefore, send a request + // only when it's about to be activated + typing + }; if send { let typing = if typing { self.client + .inner .typing_notice_times .insert(self.inner.room_id().clone(), Instant::now()); Typing::Yes(TYPING_NOTICE_TIMEOUT) } else { - self.client.typing_notice_times.remove(self.inner.room_id()); + self.client.inner.typing_notice_times.remove(self.inner.room_id()); Typing::No }; @@ -278,7 +280,7 @@ impl Joined { /// if let Some(room) = client.get_joined_room(&room_id) { /// room.enable_encryption().await? /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn enable_encryption(&self) -> Result<()> { use ruma::{ @@ -294,7 +296,7 @@ impl Joined { // TODO do we want to return an error here if we time out? This // could be quite useful if someone wants to enable encryption and // send a message right after it's enabled. - self.client.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME); + self.client.inner.sync_beat.listen().wait_timeout(SYNC_WAIT_TIME); } Ok(()) @@ -311,7 +313,7 @@ impl Joined { // TODO expose this publicly so people can pre-share a group session if // e.g. a user starts to type a message for a room. if let Some(mutex) = - self.client.group_session_locks.get(self.inner.room_id()).map(|m| m.clone()) + self.client.inner.group_session_locks.get(self.inner.room_id()).map(|m| m.clone()) { // If a group session share request is already going on, // await the release of the lock. @@ -320,7 +322,10 @@ impl Joined { // Otherwise create a new lock and share the group // session. let mutex = Arc::new(Mutex::new(())); - self.client.group_session_locks.insert(self.inner.room_id().clone(), mutex.clone()); + self.client + .inner + .group_session_locks + .insert(self.inner.room_id().clone(), mutex.clone()); let _guard = mutex.lock().await; @@ -334,13 +339,13 @@ impl Joined { let response = self.share_group_session().await; - self.client.group_session_locks.remove(self.inner.room_id()); + self.client.inner.group_session_locks.remove(self.inner.room_id()); // If one of the responses failed invalidate the group // session as using it would end up in undecryptable // messages. if let Err(r) = response { - self.client.base_client.invalidate_group_session(self.inner.room_id()).await?; + self.client.base_client().invalidate_group_session(self.inner.room_id()).await?; return Err(r); } } @@ -357,12 +362,12 @@ impl Joined { #[cfg(feature = "encryption")] async fn share_group_session(&self) -> Result<()> { let mut requests = - self.client.base_client.share_group_session(self.inner.room_id()).await?; + self.client.base_client().share_group_session(self.inner.room_id()).await?; for request in requests.drain(..) { let response = self.client.send_to_device(&request).await?; - self.client.base_client.mark_request_as_sent(&request.txn_id, &response).await?; + self.client.mark_request_as_sent(&request.txn_id, &response).await?; } Ok(()) @@ -449,7 +454,7 @@ impl Joined { /// if let Some(room) = client.get_joined_room(&room_id) { /// room.send(content, Some(txn_id)).await?; /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` /// /// [`SyncMessageEvent`]: ruma::events::SyncMessageEvent @@ -517,7 +522,7 @@ impl Joined { /// if let Some(room) = client.get_joined_room(&room_id) { /// room.send_raw(content, "m.room.message", None).await?; /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` /// /// [`SyncMessageEvent`]: ruma::events::SyncMessageEvent @@ -554,8 +559,7 @@ impl Joined { self.preshare_group_session().await?; - let olm = - self.client.base_client.olm_machine().await.expect("Olm machine wasn't started"); + let olm = self.client.olm_machine().await.expect("Olm machine wasn't started"); let encrypted_content = olm.encrypt_raw(self.inner.room_id(), content, event_type).await?; @@ -631,7 +635,7 @@ impl Joined { /// None, /// ).await?; /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn send_attachment( &self, @@ -698,7 +702,7 @@ impl Joined { /// if let Some(room) = client.get_joined_room(&room_id) { /// room.send_state_event(content, "").await?; /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn send_state_event( &self, @@ -791,7 +795,7 @@ impl Joined { /// let reason = Some("Indecent material"); /// room.redact(&event_id, reason, None).await?; /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn redact( &self, @@ -834,7 +838,7 @@ impl Joined { /// /// room.set_tag("u.work", tag_info ).await?; /// } - /// # matrix_sdk::Result::Ok(()) }); + /// # Result::<_, matrix_sdk::Error>::Ok(()) }); /// ``` pub async fn set_tag(&self, tag: &str, tag_info: TagInfo) -> HttpResult { let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?; diff --git a/crates/matrix-sdk/src/sync.rs b/crates/matrix-sdk/src/sync.rs new file mode 100644 index 000000000..398fbd914 --- /dev/null +++ b/crates/matrix-sdk/src/sync.rs @@ -0,0 +1,144 @@ +use std::time::Duration; + +use futures_timer::Delay as sleep; +use matrix_sdk_base::{ + deserialized_responses::{JoinedRoom, LeftRoom, SyncResponse}, + instant::Instant, +}; +use ruma::api::client::r0::sync::sync_events; +use tracing::{error, warn}; + +use crate::{event_handler::EventKind, Client, Result}; + +/// Internal functionality related to getting events from the server +/// (`sync_events` endpoint) +impl Client { + pub(crate) async fn process_sync( + &self, + response: sync_events::Response, + ) -> Result { + let response = self.base_client().receive_sync_response(response).await?; + let SyncResponse { + next_batch: _, + rooms, + presence, + account_data, + to_device: _, + device_lists: _, + device_one_time_keys_count: _, + ambiguity_changes: _, + notifications, + } = &response; + + self.handle_sync_events(EventKind::GlobalAccountData, &None, &account_data.events).await?; + self.handle_sync_events(EventKind::Presence, &None, &presence.events).await?; + + for (room_id, room_info) in &rooms.join { + let room = self.get_room(room_id); + if room.is_none() { + error!("Can't call event handler, room {} not found", room_id); + continue; + } + + let JoinedRoom { unread_notifications: _, timeline, state, account_data, ephemeral } = + room_info; + + self.handle_sync_events(EventKind::EphemeralRoomData, &room, &ephemeral.events).await?; + self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events) + .await?; + self.handle_sync_state_events(&room, &state.events).await?; + self.handle_sync_timeline_events(&room, &timeline.events).await?; + } + + for (room_id, room_info) in &rooms.leave { + let room = self.get_room(room_id); + if room.is_none() { + error!("Can't call event handler, room {} not found", room_id); + continue; + } + + let LeftRoom { timeline, state, account_data } = room_info; + + self.handle_sync_events(EventKind::RoomAccountData, &room, &account_data.events) + .await?; + self.handle_sync_state_events(&room, &state.events).await?; + self.handle_sync_timeline_events(&room, &timeline.events).await?; + } + + for (room_id, room_info) in &rooms.invite { + let room = self.get_room(room_id); + if room.is_none() { + error!("Can't call event handler, room {} not found", room_id); + continue; + } + + // FIXME: Destructure room_info + self.handle_sync_events( + EventKind::StrippedState, + &room, + &room_info.invite_state.events, + ) + .await?; + } + + // Construct notification event handler futures + let mut futures = Vec::new(); + for handler in &*self.notification_handlers().await { + for (room_id, room_notifications) in notifications { + let room = match self.get_room(room_id) { + Some(room) => room, + None => { + warn!("Can't call notification handler, room {} not found", room_id); + continue; + } + }; + + futures.extend(room_notifications.iter().map(|notification| { + (handler)(notification.clone(), room.clone(), self.clone()) + })); + } + } + + // Run the notification handler futures with the + // `self.notification_handlers` lock no longer being held, in order. + for fut in futures { + fut.await; + } + + Ok(response) + } + + pub(crate) async fn sync_loop_helper( + &self, + sync_settings: &mut crate::config::SyncSettings<'_>, + ) -> Result { + let response = self.sync_once(sync_settings.clone()).await; + + match response { + Ok(r) => { + sync_settings.token = Some(r.next_batch.clone()); + Ok(r) + } + Err(e) => { + error!("Received an invalid response: {}", e); + sleep::new(Duration::from_secs(1)).await; + Err(e) + } + } + } + + pub(crate) async fn delay_sync(last_sync_time: &mut Option) { + let now = Instant::now(); + + // If the last sync happened less than a second ago, sleep for a + // while to not hammer out requests if the server doesn't respect + // the sync timeout. + if let Some(t) = last_sync_time { + if now - *t <= Duration::from_secs(1) { + sleep::new(Duration::from_secs(1)).await; + } + } + + *last_sync_time = Some(now); + } +}