diff --git a/matrix_sdk/Cargo.toml b/matrix_sdk/Cargo.toml index cc1ba5807..608f6c57d 100644 --- a/matrix_sdk/Cargo.toml +++ b/matrix_sdk/Cargo.toml @@ -31,7 +31,7 @@ docs = ["encryption", "sqlite_cryptostore", "messages"] async-trait = "0.1.41" dashmap = { version = "3.11.10", optional = true } http = "0.2.1" -serde_json = "1.0.58" +serde_json = "1.0.59" thiserror = "1.0.21" tracing = "0.1.21" url = "2.1.1" @@ -73,7 +73,7 @@ async-std = { version = "1.6.5", features = ["unstable"] } dirs = "3.0.1" matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } -serde_json = "1.0.58" +serde_json = "1.0.59" tracing-subscriber = "0.2.13" tempfile = "3.1.0" mockito = "0.27.0" diff --git a/matrix_sdk/examples/cross_signing_bootstrap.rs b/matrix_sdk/examples/cross_signing_bootstrap.rs new file mode 100644 index 000000000..54def1cb3 --- /dev/null +++ b/matrix_sdk/examples/cross_signing_bootstrap.rs @@ -0,0 +1,119 @@ +use std::{ + collections::BTreeMap, + env, io, + process::exit, + sync::atomic::{AtomicBool, Ordering}, +}; + +use serde_json::json; +use url::Url; + +use matrix_sdk::{ + self, + api::r0::uiaa::AuthData, + identifiers::{user_id, UserId}, + Client, ClientConfig, LoopCtrl, SyncSettings, +}; + +fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> { + let mut auth_parameters = BTreeMap::new(); + let identifier = json!({ + "type": "m.id.user", + "user": user, + }); + + auth_parameters.insert("identifier".to_owned(), identifier); + auth_parameters.insert("password".to_owned(), password.to_owned().into()); + + // This is needed because of https://github.com/matrix-org/synapse/issues/5665 + auth_parameters.insert("user".to_owned(), user.as_str().into()); + + AuthData::DirectRequest { + kind: "m.login.password", + auth_parameters, + session, + } +} + +async fn bootstrap(client: Client) { + println!("Bootstrapping a new cross signing identity, press enter to continue."); + + let mut input = String::new(); + + io::stdin() + .read_line(&mut input) + .expect("error: unable to read user input"); + + if let Err(e) = client.bootstrap_cross_signing(None).await { + if let Some(response) = e.uiaa_response() { + let auth_data = auth_data( + &user_id!("@example:localhost"), + "wordpass", + response.session.as_deref(), + ); + client + .bootstrap_cross_signing(Some(auth_data)) + .await + .expect("Couldn't bootstrap cross signing") + } else { + panic!("Error durign cross signing bootstrap {:#?}", e); + } + } +} + +async fn login( + homeserver_url: String, + username: &str, + password: &str, +) -> Result<(), matrix_sdk::Error> { + let client_config = ClientConfig::new() + .disable_ssl_verification() + .proxy("http://localhost:8080") + .unwrap(); + let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); + let client = Client::new_with_config(homeserver_url, client_config).unwrap(); + + client + .login(username, password, None, Some("rust-sdk")) + .await?; + + let client_ref = &client; + let asked = AtomicBool::new(false); + let asked_ref = &asked; + + client + .sync_with_callback(SyncSettings::new(), |_| async move { + let asked = asked_ref; + let client = &client_ref; + + // Wait for sync to be done then ask the user to bootstrap. + if !asked.load(Ordering::SeqCst) { + tokio::spawn(bootstrap((*client).clone())); + } + + asked.store(true, Ordering::SeqCst); + LoopCtrl::Continue + }) + .await; + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), matrix_sdk::Error> { + tracing_subscriber::fmt::init(); + + let (homeserver_url, username, password) = + match (env::args().nth(1), env::args().nth(2), env::args().nth(3)) { + (Some(a), Some(b), Some(c)) => (a, b, c), + _ => { + eprintln!( + "Usage: {} ", + env::args().next().unwrap() + ); + exit(1) + } + }; + + login(homeserver_url, &username, &password).await +} diff --git a/matrix_sdk/src/client.rs b/matrix_sdk/src/client.rs index c4cf808fe..324025e0b 100644 --- a/matrix_sdk/src/client.rs +++ b/matrix_sdk/src/client.rs @@ -65,7 +65,7 @@ pub enum LoopCtrl { use matrix_sdk_common::{ api::r0::{ account::register, - device::get_devices, + device::{delete_devices, get_devices}, directory::{get_public_rooms, get_public_rooms_filtered}, media::create_content, membership::{ @@ -82,6 +82,7 @@ use matrix_sdk_common::{ typing::create_typing_event::{ Request as TypingRequest, Response as TypingResponse, Typing, }, + uiaa::AuthData, }, assign, events::{ @@ -94,7 +95,7 @@ use matrix_sdk_common::{ }, AnyMessageEventContent, }, - identifiers::{EventId, RoomId, RoomIdOrAliasId, ServerName, UserId}, + identifiers::{DeviceIdBox, EventId, RoomId, RoomIdOrAliasId, ServerName, UserId}, instant::{Duration, Instant}, js_int::UInt, locks::RwLock, @@ -106,12 +107,11 @@ use matrix_sdk_common::{ #[cfg(feature = "encryption")] use matrix_sdk_common::{ api::r0::{ - keys::{get_keys, upload_keys}, + keys::{get_keys, upload_keys, upload_signing_keys::Request as UploadSigningKeysRequest}, to_device::send_event_to_device::{ Request as RumaToDeviceRequest, Response as ToDeviceResponse, }, }, - identifiers::DeviceIdBox, locks::Mutex, }; @@ -1369,6 +1369,71 @@ impl Client { self.send(request).await } + /// Delete the given devices from the server. + /// + /// # Arguments + /// + /// * `devices` - The list of devices that should be deleted from the + /// server. + /// + /// * `auth_data` - This request requires user interactive auth, the first + /// request needs to set this to `None` and will always fail with an + /// `UiaaResponse`. The response will contain information for the + /// interactive auth and the same request needs to be made but this time + /// with some `auth_data` provided. + /// + /// ```no_run + /// # use matrix_sdk::{ + /// # api::r0::uiaa::{UiaaResponse, AuthData}, + /// # Client, SyncSettings, Error, FromHttpResponseError, ServerError, + /// # }; + /// # use futures::executor::block_on; + /// # use serde_json::json; + /// # use url::Url; + /// # use std::{collections::BTreeMap, convert::TryFrom}; + /// # block_on(async { + /// # let homeserver = Url::parse("http://localhost:8080").unwrap(); + /// # let mut client = Client::new(homeserver).unwrap(); + /// let devices = &["DEVICEID".into()]; + /// + /// if let Err(e) = client.delete_devices(devices, None).await { + /// if let Some(info) = e.uiaa_response() { + /// let mut auth_parameters = BTreeMap::new(); + /// + /// let identifier = json!({ + /// "type": "m.id.user", + /// "user": "example", + /// }); + /// auth_parameters.insert("identifier".to_owned(), identifier); + /// auth_parameters.insert("password".to_owned(), "wordpass".into()); + /// + /// // This is needed because of https://github.com/matrix-org/synapse/issues/5665 + /// auth_parameters.insert("user".to_owned(), "@example:localhost".into()); + /// + /// let auth_data = AuthData::DirectRequest { + /// kind: "m.login.password", + /// auth_parameters, + /// session: info.session.as_deref(), + /// }; + /// + /// client + /// .delete_devices(devices, Some(auth_data)) + /// .await + /// .expect("Can't delete devices"); + /// } + /// } + /// # }); + pub async fn delete_devices( + &self, + devices: &[DeviceIdBox], + auth_data: Option>, + ) -> Result { + let mut request = delete_devices::Request::new(devices); + request.auth = auth_data; + + self.send(request).await + } + /// Synchronize the client's state with the latest state on the server. /// /// **Note**: You should not use this method to repeatedly sync if encryption @@ -1742,6 +1807,33 @@ impl Client { })) } + /// TODO + #[cfg(feature = "encryption")] + #[cfg_attr(feature = "docs", doc(cfg(encryption)))] + pub async fn bootstrap_cross_signing(&self, auth_data: Option>) -> Result<()> { + let olm = self + .base_client + .olm_machine() + .await + .ok_or(Error::AuthenticationRequired)?; + + let (request, signature_request) = olm.bootstrap_cross_signing(false).await?; + + println!("HELLOOO MAKING REQUEST {:#?}", request); + + let request = UploadSigningKeysRequest { + auth: auth_data, + master_key: request.master_key, + self_signing_key: request.self_signing_key, + user_signing_key: request.user_signing_key, + }; + + self.send(request).await?; + self.send(signature_request).await?; + + Ok(()) + } + /// Get a map holding all the devices of an user. /// /// This will always return an empty map if the client hasn't been logged @@ -1800,6 +1892,8 @@ impl Client { /// /// # Panics /// + /// This method will panic if it isn't run on a Tokio runtime. + /// /// This method will panic if it can't get enough randomness from the OS to /// encrypt the exported keys securely. /// @@ -1868,6 +1962,17 @@ impl Client { /// Import E2EE keys from the given file path. /// + /// # Arguments + /// + /// * `path` - The file path where the exported key file will can be found. + /// + /// * `passphrase` - The passphrase that should be used to decrypt the + /// exported room keys. + /// + /// # Panics + /// + /// This method will panic if it isn't run on a Tokio runtime. + /// /// ```no_run /// # use std::{path::PathBuf, time::Duration}; /// # use matrix_sdk::{ @@ -1893,7 +1998,7 @@ impl Client { feature = "docs", doc(cfg(all(encryption, not(target_arch = "wasm32")))) )] - pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result<()> { + pub async fn import_keys(&self, path: PathBuf, passphrase: &str) -> Result { let olm = self .base_client .olm_machine() @@ -1909,9 +2014,7 @@ impl Client { let task = tokio::task::spawn_blocking(decrypt); let import = task.await.expect("Task join error").unwrap(); - olm.import_keys(import).await.unwrap(); - - Ok(()) + Ok(olm.import_keys(import).await?) } } @@ -1941,7 +2044,10 @@ mod test { use serde_json::json; use tempfile::tempdir; - use std::{convert::TryInto, io::Cursor, path::Path, str::FromStr, time::Duration}; + use std::{ + collections::BTreeMap, convert::TryInto, io::Cursor, path::Path, str::FromStr, + time::Duration, + }; async fn logged_in_client() -> Client { let session = Session { @@ -2740,4 +2846,61 @@ mod test { assert_eq!("tutorial".to_string(), room.read().await.display_name()); } + + #[tokio::test] + async fn delete_devices() { + let homeserver = Url::from_str(&mockito::server_url()).unwrap(); + let client = Client::new(homeserver).unwrap(); + + let _m = mock("POST", "/_matrix/client/r0/delete_devices") + .with_status(401) + .with_body( + json!({ + "flows": [ + { + "stages": [ + "m.login.password" + ] + } + ], + "params": {}, + "session": "vBslorikviAjxzYBASOBGfPp" + }) + .to_string(), + ) + .create(); + + let _m = mock("POST", "/_matrix/client/r0/delete_devices") + .with_status(401) + // empty response + // TODO rename that response type. + .with_body(test_json::LOGOUT.to_string()) + .create(); + + let devices = &["DEVICEID".into()]; + + if let Err(e) = client.delete_devices(devices, None).await { + if let Some(info) = e.uiaa_response() { + let mut auth_parameters = BTreeMap::new(); + + let identifier = json!({ + "type": "m.id.user", + "user": "example", + }); + auth_parameters.insert("identifier".to_owned(), identifier); + auth_parameters.insert("password".to_owned(), "wordpass".into()); + + let auth_data = AuthData::DirectRequest { + kind: "m.login.password", + auth_parameters, + session: info.session.as_deref(), + }; + + client + .delete_devices(devices, Some(auth_data)) + .await + .unwrap(); + } + } + } } diff --git a/matrix_sdk/src/device.rs b/matrix_sdk/src/device.rs index 5c7b55fba..776b9f0ea 100644 --- a/matrix_sdk/src/device.rs +++ b/matrix_sdk/src/device.rs @@ -19,7 +19,8 @@ use matrix_sdk_base::crypto::{ UserDevices as BaseUserDevices, }; use matrix_sdk_common::{ - api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, identifiers::DeviceId, + api::r0::to_device::send_event_to_device::Request as ToDeviceRequest, + identifiers::{DeviceId, DeviceIdBox}, }; use crate::{error::Result, http_client::HttpClient, Sas}; @@ -114,7 +115,7 @@ impl UserDevices { } /// Iterator over all the device ids of the user devices. - pub fn keys(&self) -> impl Iterator { + pub fn keys(&self) -> impl Iterator { self.inner.keys() } diff --git a/matrix_sdk/src/error.rs b/matrix_sdk/src/error.rs index bdca99121..eda93c962 100644 --- a/matrix_sdk/src/error.rs +++ b/matrix_sdk/src/error.rs @@ -16,8 +16,11 @@ use matrix_sdk_base::Error as MatrixError; use matrix_sdk_common::{ - api::{r0::uiaa::UiaaResponse as UiaaError, Error as RumaClientError}, - FromHttpResponseError as RumaResponseError, IntoHttpError as RumaIntoHttpError, + api::{ + r0::uiaa::{UiaaInfo, UiaaResponse as UiaaError}, + Error as RumaClientError, + }, + FromHttpResponseError as RumaResponseError, IntoHttpError as RumaIntoHttpError, ServerError, }; use reqwest::Error as ReqwestError; use serde_json::Error as JsonError; @@ -79,6 +82,30 @@ pub enum Error { UiaaError(RumaResponseError), } +impl Error { + /// Try to destructure the error into an universal interactive auth info. + /// + /// Some requests require universal interactive auth, doing such a request + /// will always fail the first time with a 401 status code, the response + /// body will contain info how the client can authenticate. + /// + /// The request will need to be retried, this time containing additional + /// authentication data. + /// + /// This method is an convenience method to get to the info the server + /// returned on the first, failed request. + pub fn uiaa_response(&self) -> Option<&UiaaInfo> { + if let Error::UiaaError(RumaResponseError::Http(ServerError::Known( + UiaaError::AuthResponse(i), + ))) = self + { + Some(i) + } else { + None + } + } +} + impl From> for Error { fn from(error: RumaResponseError) -> Self { Self::UiaaError(error) diff --git a/matrix_sdk_base/Cargo.toml b/matrix_sdk_base/Cargo.toml index 99033cf69..a42d87256 100644 --- a/matrix_sdk_base/Cargo.toml +++ b/matrix_sdk_base/Cargo.toml @@ -27,7 +27,7 @@ docs = ["encryption", "sqlite_cryptostore", "messages"] async-trait = "0.1.41" serde = "1.0.116" dashmap= "*" -serde_json = "1.0.58" +serde_json = "1.0.59" zeroize = "1.1.1" tracing = "0.1.21" diff --git a/matrix_sdk_common/Cargo.toml b/matrix_sdk_common/Cargo.toml index 7302ebfce..cf1a8c272 100644 --- a/matrix_sdk_common/Cargo.toml +++ b/matrix_sdk_common/Cargo.toml @@ -20,7 +20,7 @@ js_int = "0.1.9" [dependencies.ruma] version = "0.0.1" -git = "https://github.com/ruma/ruma" +path = "/home/poljar/werk/priv/ruma/ruma" rev = "50eb700571480d1440e15a387d10f98be8abab59" features = ["client-api", "unstable-pre-spec", "unstable-exhaustive-types"] diff --git a/matrix_sdk_crypto/Cargo.toml b/matrix_sdk_crypto/Cargo.toml index 023c9cb3e..4e7ee1576 100644 --- a/matrix_sdk_crypto/Cargo.toml +++ b/matrix_sdk_crypto/Cargo.toml @@ -28,7 +28,7 @@ matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } olm-rs = { version = "1.0.0", features = ["serde"] } getrandom = "0.2.0" serde = { version = "1.0.116", features = ["derive", "rc"] } -serde_json = "1.0.58" +serde_json = "1.0.59" cjson = "0.1.1" zeroize = { version = "1.1.1", features = ["zeroize_derive"] } url = "2.1.1" @@ -39,6 +39,7 @@ tracing = "0.1.21" atomic = "0.5.0" dashmap = "3.11.10" sha2 = "0.9.1" +aes-gcm = "0.7.0" aes-ctr = "0.5.0" pbkdf2 = { version = "0.5.0", default-features = false } hmac = "0.9.0" @@ -51,7 +52,8 @@ default-features = false features = ["std", "std-future"] [target.'cfg(not(target_arch = "wasm32"))'.dependencies.sqlx] -version = "0.3.5" +git = "https://github.com/launchbadge/sqlx/" +rev = "fd25a7530cf087e1529553ff854f192738db3461" optional = true default-features = false features = ["runtime-tokio", "sqlite", "macros"] @@ -60,7 +62,7 @@ features = ["runtime-tokio", "sqlite", "macros"] tokio = { version = "0.2.22", features = ["rt-threaded", "macros"] } futures = "0.3.6" proptest = "0.10.1" -serde_json = "1.0.58" +serde_json = "1.0.59" tempfile = "3.1.0" http = "0.2.1" matrix-sdk-test = { version = "0.1.0", path = "../matrix_sdk_test" } diff --git a/matrix_sdk_crypto/src/error.rs b/matrix_sdk_crypto/src/error.rs index ee4e5f1f9..1692bd546 100644 --- a/matrix_sdk_crypto/src/error.rs +++ b/matrix_sdk_crypto/src/error.rs @@ -48,8 +48,8 @@ pub enum OlmError { Store(#[from] CryptoStoreError), /// The session with a device has become corrupted. - #[error("decryption failed likely because a Olm session was wedged")] - SessionWedged, + #[error("decryption failed likely because an Olm from {0} with sender key {1} was wedged")] + SessionWedged(UserId, String), /// Encryption failed because the device does not have a valid Olm session /// with us. @@ -148,6 +148,9 @@ pub enum SignatureError { #[error("the provided JSON object can't be converted to a canonical representation")] CanonicalJsonError(CjsonError), + #[error(transparent)] + JsonError(#[from] SerdeError), + #[error("the signature didn't match the provided key")] VerificationError, } diff --git a/matrix_sdk_crypto/src/identities/device.rs b/matrix_sdk_crypto/src/identities/device.rs index 73aeb22d8..0a75ad292 100644 --- a/matrix_sdk_crypto/src/identities/device.rs +++ b/matrix_sdk_crypto/src/identities/device.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::BTreeMap, + collections::{BTreeMap, HashMap}, convert::{TryFrom, TryInto}, ops::Deref, sync::{ @@ -30,13 +30,19 @@ use matrix_sdk_common::{ forwarded_room_key::ForwardedRoomKeyEventContent, room::encrypted::EncryptedEventContent, EventType, }, - identifiers::{DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId}, + identifiers::{ + DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, UserId, + }, + locks::Mutex, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tracing::warn; -use crate::olm::InboundGroupSession; +use crate::{ + olm::{InboundGroupSession, Session}, + store::{Changes, DeviceChanges}, +}; #[cfg(test)] use crate::{OlmMachine, ReadOnlyAccount}; @@ -44,7 +50,7 @@ use crate::{ error::{EventError, OlmError, OlmResult, SignatureError}, identities::{OwnUserIdentity, UserIdentities}, olm::Utility, - store::{caches::ReadOnlyUserDevices, CryptoStore, Result as StoreResult}, + store::{CryptoStore, Result as StoreResult}, verification::VerificationMachine, Sas, ToDeviceRequest, }; @@ -89,6 +95,15 @@ impl Device { .await } + /// Get the Olm sessions that belong to this device. + pub(crate) async fn get_sessions(&self) -> StoreResult>>>> { + if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { + self.verification_machine.store.get_sessions(k).await + } else { + Ok(None) + } + } + /// Get the trust state of the device. pub fn trust_state(&self) -> bool { self.inner @@ -106,10 +121,15 @@ impl Device { pub async fn set_local_trust(&self, trust_state: LocalTrust) -> StoreResult<()> { self.inner.set_trust_state(trust_state); - self.verification_machine - .store - .save_devices(&[self.inner.clone()]) - .await + let changes = Changes { + devices: DeviceChanges { + changed: vec![self.inner.clone()], + ..Default::default() + }, + ..Default::default() + }; + + self.verification_machine.store.save_changes(changes).await } /// Encrypt the given content for this `Device`. @@ -123,7 +143,7 @@ impl Device { &self, event_type: EventType, content: Value, - ) -> OlmResult { + ) -> OlmResult<(Session, EncryptedEventContent)> { self.inner .encrypt(&**self.verification_machine.store, event_type, content) .await @@ -134,7 +154,7 @@ impl Device { pub async fn encrypt_session( &self, session: InboundGroupSession, - ) -> OlmResult { + ) -> OlmResult<(Session, EncryptedEventContent)> { let export = session.export().await; let content: ForwardedRoomKeyEventContent = if let Ok(c) = export.try_into() { @@ -158,7 +178,7 @@ impl Device { /// A read only view over all devices belonging to a user. #[derive(Debug)] pub struct UserDevices { - pub(crate) inner: ReadOnlyUserDevices, + pub(crate) inner: HashMap, pub(crate) verification_machine: VerificationMachine, pub(crate) own_identity: Option, pub(crate) device_owner_identity: Option, @@ -168,7 +188,7 @@ impl UserDevices { /// Get the specific device with the given device id. pub fn get(&self, device_id: &DeviceId) -> Option { self.inner.get(device_id).map(|d| Device { - inner: d, + inner: d.clone(), verification_machine: self.verification_machine.clone(), own_identity: self.own_identity.clone(), device_owner_identity: self.device_owner_identity.clone(), @@ -176,13 +196,13 @@ impl UserDevices { } /// Iterator over all the device ids of the user devices. - pub fn keys(&self) -> impl Iterator { + pub fn keys(&self) -> impl Iterator { self.inner.keys() } /// Iterator over all the devices of the user devices. pub fn devices(&self) -> impl Iterator + '_ { - self.inner.devices().map(move |d| Device { + self.inner.values().map(move |d| Device { inner: d.clone(), verification_machine: self.verification_machine.clone(), own_identity: self.own_identity.clone(), @@ -352,7 +372,7 @@ impl ReadOnlyDevice { store: &dyn CryptoStore, event_type: EventType, content: Value, - ) -> OlmResult { + ) -> OlmResult<(Session, EncryptedEventContent)> { let sender_key = if let Some(k) = self.get_key(DeviceKeyAlgorithm::Curve25519) { k } else { @@ -384,10 +404,9 @@ impl ReadOnlyDevice { return Err(OlmError::MissingSession); }; - let message = session.encrypt(&self, event_type, content).await; - store.save_sessions(&[session]).await?; + let message = session.encrypt(&self, event_type, content).await?; - message + Ok((session, message)) } /// Update a device with a new device keys struct. diff --git a/matrix_sdk_crypto/src/identities/manager.rs b/matrix_sdk_crypto/src/identities/manager.rs index d80c8c2e8..993dd5a4a 100644 --- a/matrix_sdk_crypto/src/identities/manager.rs +++ b/matrix_sdk_crypto/src/identities/manager.rs @@ -27,13 +27,13 @@ use matrix_sdk_common::{ use crate::{ error::OlmResult, - group_manager::GroupSessionManager, identities::{ MasterPubkey, OwnUserIdentity, ReadOnlyDevice, SelfSigningPubkey, UserIdentities, UserIdentity, UserSigningPubkey, }, requests::KeysQueryRequest, - store::{Result as StoreResult, Store}, + session_manager::GroupSessionManager, + store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, }; #[derive(Debug, Clone)] @@ -79,7 +79,7 @@ impl IdentityManager { pub async fn receive_keys_query_response( &self, response: &KeysQueryResponse, - ) -> OlmResult<(Vec, Vec)> { + ) -> OlmResult<(DeviceChanges, IdentityChanges)> { // TODO create a enum that tells us how the device/identity changed, // e.g. new/deleted/display name change. // @@ -92,9 +92,15 @@ impl IdentityManager { let changed_devices = self .handle_devices_from_key_query(&response.device_keys) .await?; - self.store.save_devices(&changed_devices).await?; let changed_identities = self.handle_cross_singing_keys(response).await?; - self.store.save_user_identities(&changed_identities).await?; + + let changes = Changes { + identities: changed_identities.clone(), + devices: changed_devices.clone(), + ..Default::default() + }; + + self.store.save_changes(changes).await?; Ok((changed_devices, changed_identities)) } @@ -111,9 +117,10 @@ impl IdentityManager { async fn handle_devices_from_key_query( &self, device_keys_map: &BTreeMap>, - ) -> StoreResult> { + ) -> StoreResult { let mut users_with_new_or_deleted_devices = HashSet::new(); - let mut changed_devices = Vec::new(); + + let mut changes = DeviceChanges::default(); for (user_id, device_map) in device_keys_map { // TODO move this out into the handle keys query response method @@ -137,7 +144,7 @@ impl IdentityManager { let device = self.store.get_readonly_device(&user_id, device_id).await?; - let device = if let Some(mut device) = device { + if let Some(mut device) = device { if let Err(e) = device.update_device(device_keys) { warn!( "Failed to update the device keys for {} {}: {:?}", @@ -145,7 +152,7 @@ impl IdentityManager { ); continue; } - device + changes.changed.push(device); } else { let device = match ReadOnlyDevice::try_from(device_keys) { Ok(d) => d, @@ -159,24 +166,21 @@ impl IdentityManager { }; info!("Adding a new device to the device store {:?}", device); users_with_new_or_deleted_devices.insert(user_id); - device - }; - - changed_devices.push(device); + changes.new.push(device); + } } - let current_devices: HashSet<&DeviceId> = - device_map.keys().map(|id| id.as_ref()).collect(); + let current_devices: HashSet<&DeviceIdBox> = device_map.keys().collect(); let stored_devices = self.store.get_readonly_devices(&user_id).await?; - let stored_devices_set: HashSet<&DeviceId> = stored_devices.keys().collect(); + let stored_devices_set: HashSet<&DeviceIdBox> = stored_devices.keys().collect(); - let deleted_devices = stored_devices_set.difference(¤t_devices); + let deleted_devices_set = stored_devices_set.difference(¤t_devices); - for device_id in deleted_devices { + for device_id in deleted_devices_set { users_with_new_or_deleted_devices.insert(user_id); - if let Some(device) = stored_devices.get(device_id) { + if let Some(device) = stored_devices.get(*device_id) { device.mark_as_deleted(); - self.store.delete_device(device).await?; + changes.deleted.push(device.clone()); } } } @@ -184,7 +188,7 @@ impl IdentityManager { self.group_manager .invalidate_sessions_new_devices(&users_with_new_or_deleted_devices); - Ok(changed_devices) + Ok(changes) } /// Handle the device keys part of a key query response. @@ -198,8 +202,8 @@ impl IdentityManager { async fn handle_cross_singing_keys( &self, response: &KeysQueryResponse, - ) -> StoreResult> { - let mut changed = Vec::new(); + ) -> StoreResult { + let mut changes = IdentityChanges::default(); for (user_id, master_key) in &response.master_keys { let master_key = MasterPubkey::from(master_key); @@ -214,7 +218,7 @@ impl IdentityManager { continue; }; - let identity = if let Some(mut i) = self.store.get_user_identity(user_id).await? { + let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? { match &mut i { UserIdentities::Own(ref mut identity) => { let user_signing = if let Some(s) = response.user_signing_keys.get(user_id) @@ -231,11 +235,11 @@ impl IdentityManager { identity .update(master_key, self_signing, user_signing) - .map(|_| i) - } - UserIdentities::Other(ref mut identity) => { - identity.update(master_key, self_signing).map(|_| i) + .map(|_| (i, false)) } + UserIdentities::Other(ref mut identity) => identity + .update(master_key, self_signing) + .map(|_| (i, false)), } } else if user_id == self.user_id() { if let Some(s) = response.user_signing_keys.get(user_id) { @@ -253,7 +257,7 @@ impl IdentityManager { } OwnUserIdentity::new(master_key, self_signing, user_signing) - .map(UserIdentities::Own) + .map(|i| (UserIdentities::Own(i), true)) } else { warn!( "User identity for our own user {} didn't contain a \ @@ -269,17 +273,22 @@ impl IdentityManager { ); continue; } else { - UserIdentity::new(master_key, self_signing).map(UserIdentities::Other) + UserIdentity::new(master_key, self_signing) + .map(|i| (UserIdentities::Other(i), true)) }; - match identity { - Ok(i) => { + match result { + Ok((i, new)) => { trace!( "Updated or created new user identity for {}: {:?}", user_id, i ); - changed.push(i); + if new { + changes.new.push(i); + } else { + changes.changed.push(i); + } } Err(e) => { warn!( @@ -291,7 +300,7 @@ impl IdentityManager { } } - Ok(changed) + Ok(changes) } /// Get a key query request if one is needed. @@ -369,6 +378,7 @@ pub(crate) mod test { use matrix_sdk_common::{ api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId}, + locks::Mutex, }; use matrix_sdk_test::async_test; @@ -376,10 +386,10 @@ pub(crate) mod test { use serde_json::json; use crate::{ - group_manager::GroupSessionManager, identities::IdentityManager, machine::test::response_from_file, - olm::{Account, ReadOnlyAccount}, + olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionManager, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -401,10 +411,11 @@ pub(crate) mod test { } fn manager() -> IdentityManager { + let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id()))); let user_id = Arc::new(user_id()); let account = ReadOnlyAccount::new(&user_id, &device_id()); let store: Arc> = Arc::new(Box::new(MemoryStore::new())); - let verification = VerificationMachine::new(account.clone(), store); + let verification = VerificationMachine::new(account.clone(), identity, store); let store = Store::new( user_id.clone(), Arc::new(Box::new(MemoryStore::new())), diff --git a/matrix_sdk_crypto/src/identities/user.rs b/matrix_sdk_crypto/src/identities/user.rs index a23dbc96a..70026de05 100644 --- a/matrix_sdk_crypto/src/identities/user.rs +++ b/matrix_sdk_crypto/src/identities/user.rs @@ -86,6 +86,24 @@ impl From for UserSigningPubkey { } } +impl Into for MasterPubkey { + fn into(self) -> CrossSigningKey { + self.0.as_ref().clone() + } +} + +impl Into for UserSigningPubkey { + fn into(self) -> CrossSigningKey { + self.0.as_ref().clone() + } +} + +impl Into for SelfSigningPubkey { + fn into(self) -> CrossSigningKey { + self.0.as_ref().clone() + } +} + impl AsRef for MasterPubkey { fn as_ref(&self) -> &CrossSigningKey { &self.0 @@ -135,7 +153,7 @@ impl<'a> From<&'a UserSigningPubkey> for CrossSigningSubKeys<'a> { } /// Enum over the cross signing sub-keys. -enum CrossSigningSubKeys<'a> { +pub(crate) enum CrossSigningSubKeys<'a> { /// The self signing subkey. SelfSigning(&'a SelfSigningPubkey), /// The user signing subkey. @@ -152,7 +170,7 @@ impl<'a> CrossSigningSubKeys<'a> { } /// Get the `CrossSigningKey` from an sub-keys enum - fn cross_signing_key(&self) -> &CrossSigningKey { + pub(crate) fn cross_signing_key(&self) -> &CrossSigningKey { match self { CrossSigningSubKeys::SelfSigning(key) => &key.0, CrossSigningSubKeys::UserSigning(key) => &key.0, @@ -198,7 +216,7 @@ impl MasterPubkey { /// /// Returns an empty result if the signature check succeeded, otherwise a /// SignatureError indicating why the check failed. - fn verify_subkey<'a>( + pub(crate) fn verify_subkey<'a>( &self, subkey: impl Into>, ) -> Result<(), SignatureError> { @@ -666,13 +684,13 @@ pub(crate) mod test { manager::test::{other_key_query, own_key_query}, Device, ReadOnlyDevice, }, - olm::ReadOnlyAccount, + olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, store::MemoryStore, verification::VerificationMachine, }; use matrix_sdk_common::{ - api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, + api::r0::keys::get_keys::Response as KeyQueryResponse, identifiers::user_id, locks::Mutex, }; use super::{OwnUserIdentity, UserIdentities, UserIdentity}; @@ -734,8 +752,12 @@ pub(crate) mod test { assert!(identity.is_device_signed(&first).is_err()); assert!(identity.is_device_signed(&second).is_ok()); + let private_identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty( + second.user_id().clone(), + ))); let verification_machine = VerificationMachine::new( ReadOnlyAccount::new(second.user_id(), second.device_id()), + private_identity, Arc::new(Box::new(MemoryStore::new())), ); diff --git a/matrix_sdk_crypto/src/key_request.rs b/matrix_sdk_crypto/src/key_request.rs index 84f4eab27..ba98f62e5 100644 --- a/matrix_sdk_crypto/src/key_request.rs +++ b/matrix_sdk_crypto/src/key_request.rs @@ -41,7 +41,7 @@ use matrix_sdk_common::{ use crate::{ error::{OlmError, OlmResult}, - olm::{InboundGroupSession, OutboundGroupSession}, + olm::{InboundGroupSession, OutboundGroupSession, Session}, requests::{OutgoingRequest, ToDeviceRequest}, store::{CryptoStoreError, Store}, Device, @@ -196,6 +196,7 @@ impl KeyRequestMachine { device_id: Arc, store: Store, outbound_group_sessions: Arc>, + users_for_key_claim: Arc>>, ) -> Self { Self { user_id, @@ -205,7 +206,7 @@ impl KeyRequestMachine { outgoing_to_device_requests: Arc::new(DashMap::new()), incoming_key_requests: Arc::new(DashMap::new()), wait_queue: WaitQueue::new(), - users_for_key_claim: Arc::new(DashMap::new()), + users_for_key_claim, } } @@ -214,11 +215,6 @@ impl KeyRequestMachine { &self.user_id } - /// Get the map of user/devices which we need to claim one-time for. - pub fn users_for_key_claim(&self) -> &DashMap> { - &self.users_for_key_claim - } - pub fn outgoing_to_device_requests(&self) -> Vec { #[allow(clippy::map_clone)] self.outgoing_to_device_requests @@ -239,15 +235,18 @@ impl KeyRequestMachine { /// Handle all the incoming key requests that are queued up and empty our /// key request queue. - pub async fn collect_incoming_key_requests(&self) -> OlmResult<()> { + pub async fn collect_incoming_key_requests(&self) -> OlmResult> { + let mut changed_sessions = Vec::new(); for item in self.incoming_key_requests.iter() { let event = item.value(); - self.handle_key_request(event).await?; + if let Some(s) = self.handle_key_request(event).await? { + changed_sessions.push(s); + } } self.incoming_key_requests.clear(); - Ok(()) + Ok(changed_sessions) } /// Store the key share request for later, once we get an Olm session with @@ -298,7 +297,7 @@ impl KeyRequestMachine { async fn handle_key_request( &self, event: &ToDeviceEvent, - ) -> OlmResult<()> { + ) -> OlmResult> { let key_info = match event.content.action { Action::Request => { if let Some(info) = &event.content.body { @@ -309,11 +308,11 @@ impl KeyRequestMachine { action, but no key info was found", event.sender, event.content.requesting_device_id ); - return Ok(()); + return Ok(None); } } // We ignore cancellations here since there's nothing to serve. - Action::CancelRequest => return Ok(()), + Action::CancelRequest => return Ok(None), }; let session = self @@ -332,7 +331,7 @@ impl KeyRequestMachine { "Received a key request from {} {} for an unknown inbound group session {}.", &event.sender, &event.content.requesting_device_id, &key_info.session_id ); - return Ok(()); + return Ok(None); }; let device = self @@ -353,6 +352,8 @@ impl KeyRequestMachine { device.device_id(), e ); + + Ok(None) } else { info!( "Serving a key request for {} from {} {}.", @@ -361,20 +362,20 @@ impl KeyRequestMachine { device.device_id() ); - if let Err(e) = self.share_session(&session, &device).await { - match e { - OlmError::MissingSession => { - info!( - "Key request from {} {} is missing an Olm session, \ - putting the request in the wait queue", - device.user_id(), - device.device_id() - ); - self.handle_key_share_without_session(device, event); - return Ok(()); - } - e => return Err(e), + match self.share_session(&session, &device).await { + Ok(s) => Ok(Some(s)), + Err(OlmError::MissingSession) => { + info!( + "Key request from {} {} is missing an Olm session, \ + putting the request in the wait queue", + device.user_id(), + device.device_id() + ); + self.handle_key_share_without_session(device, event); + + Ok(None) } + Err(e) => Err(e), } } } else { @@ -383,13 +384,17 @@ impl KeyRequestMachine { &event.sender, &event.content.requesting_device_id ); self.store.update_tracked_user(&event.sender, true).await?; - } - Ok(()) + Ok(None) + } } - async fn share_session(&self, session: &InboundGroupSession, device: &Device) -> OlmResult<()> { - let content = device.encrypt_session(session.clone()).await?; + async fn share_session( + &self, + session: &InboundGroupSession, + device: &Device, + ) -> OlmResult { + let (used_session, content) = device.encrypt_session(session.clone()).await?; let id = Uuid::new_v4(); let mut messages = BTreeMap::new(); @@ -416,7 +421,7 @@ impl KeyRequestMachine { self.outgoing_to_device_requests.insert(id, request); - Ok(()) + Ok(used_session) } /// Check if it's ok to share a session with the given device. @@ -573,23 +578,20 @@ impl KeyRequestMachine { Ok(()) } - /// Save an inbound group session we received using a key forward. + /// Mark the given outgoing key info as done. /// - /// At the same time delete the key info since we received the wanted key. - async fn save_session( - &self, - key_info: OugoingKeyInfo, - session: InboundGroupSession, - ) -> Result<(), CryptoStoreError> { + /// This will queue up a request cancelation. + async fn mark_as_done(&self, key_info: OugoingKeyInfo) -> Result<(), CryptoStoreError> { // TODO perhaps only remove the key info if the first known index is 0. trace!( "Successfully received a forwarded room key for {:#?}", key_info ); - self.store.save_inbound_group_sessions(&[session]).await?; self.outgoing_to_device_requests .remove(&key_info.request_id); + // TODO return the key info instead of deleting it so the sync handler + // can delete it in one transaction. self.delete_key_info(&key_info).await?; let content = RoomKeyRequestEventContent { @@ -613,7 +615,8 @@ impl KeyRequestMachine { &self, sender_key: &str, event: &mut ToDeviceEvent, - ) -> Result>, CryptoStoreError> { + ) -> Result<(Option>, Option), CryptoStoreError> + { let key_info = self.get_key_info(&event.content).await?; if let Some(info) = key_info { @@ -630,27 +633,32 @@ impl KeyRequestMachine { // If we have a previous session, check if we have a better version // and store the new one if so. - if let Some(old_session) = old_session { + let session = if let Some(old_session) = old_session { let first_old_index = old_session.first_known_index().await; let first_index = session.first_known_index().await; if first_old_index > first_index { - self.save_session(info, session).await?; + self.mark_as_done(info).await?; + Some(session) + } else { + None } // If we didn't have a previous session, store it. } else { - self.save_session(info, session).await?; - } + self.mark_as_done(info).await?; + Some(session) + }; - Ok(Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey( - event.clone(), - )))) + Ok(( + Some(Raw::from(AnyToDeviceEvent::ForwardedRoomKey(event.clone()))), + session, + )) } else { info!( "Received a forwarded room key from {}, but no key info was found.", event.sender, ); - Ok(None) + Ok((None, None)) } } } @@ -666,13 +674,14 @@ mod test { AnyToDeviceEvent, ToDeviceEvent, }, identifiers::{room_id, user_id, DeviceIdBox, RoomId, UserId}, + locks::Mutex, }; use matrix_sdk_test::async_test; use std::{convert::TryInto, sync::Arc}; use crate::{ identities::{LocalTrust, ReadOnlyDevice}, - olm::{Account, ReadOnlyAccount}, + olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, store::{CryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -711,7 +720,8 @@ mod test { let user_id = Arc::new(bob_id()); let account = ReadOnlyAccount::new(&user_id, &alice_device_id()); let store: Arc> = Arc::new(Box::new(MemoryStore::new())); - let verification = VerificationMachine::new(account, store.clone()); + let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(bob_id()))); + let verification = VerificationMachine::new(account, identity, store.clone()); let store = Store::new(user_id.clone(), store, verification); KeyRequestMachine::new( @@ -719,6 +729,7 @@ mod test { Arc::new(bob_device_id()), store, Arc::new(DashMap::new()), + Arc::new(DashMap::new()), ) } @@ -727,7 +738,8 @@ mod test { let account = ReadOnlyAccount::new(&user_id, &alice_device_id()); let device = ReadOnlyDevice::from_account(&account).await; let store: Arc> = Arc::new(Box::new(MemoryStore::new())); - let verification = VerificationMachine::new(account, store.clone()); + let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); + let verification = VerificationMachine::new(account, identity, store.clone()); let store = Store::new(user_id.clone(), store, verification); store.save_devices(&[device]).await.unwrap(); @@ -736,6 +748,7 @@ mod test { Arc::new(alice_device_id()), store, Arc::new(DashMap::new()), + Arc::new(DashMap::new()), ) } @@ -833,20 +846,20 @@ mod test { .is_none() ); - machine + let (_, first_session) = machine .receive_forwarded_room_key(&session.sender_key, &mut event) .await .unwrap(); - - let first_session = machine - .store - .get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); + let first_session = first_session.unwrap(); assert_eq!(first_session.first_known_index().await, 10); + machine + .store + .save_inbound_group_sessions(&[first_session.clone()]) + .await + .unwrap(); + // Get the cancel request. let request = machine.outgoing_to_device_requests.iter().next().unwrap(); let id = request.request_id; @@ -877,19 +890,12 @@ mod test { content, }; - machine + let (_, second_session) = machine .receive_forwarded_room_key(&session.sender_key, &mut event) .await .unwrap(); - let second_session = machine - .store - .get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); - - assert_eq!(second_session.first_known_index().await, 10); + assert!(second_session.is_none()); let export = session.export_at_index(0).await.unwrap(); @@ -900,18 +906,12 @@ mod test { content, }; - machine + let (_, second_session) = machine .receive_forwarded_room_key(&session.sender_key, &mut event) .await .unwrap(); - let second_session = machine - .store - .get_inbound_group_session(session.room_id(), &session.sender_key, session.session_id()) - .await - .unwrap() - .unwrap(); - assert_eq!(second_session.first_known_index().await, 0); + assert_eq!(second_session.unwrap().first_known_index().await, 0); } #[async_test] @@ -1134,14 +1134,19 @@ mod test { .unwrap() .is_none()); - let (decrypted, sender_key, _) = + let (_, decrypted, sender_key, _) = alice_account.decrypt_to_device_event(&event).await.unwrap(); if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { - alice_machine + let (_, session) = alice_machine .receive_forwarded_room_key(&sender_key, &mut e) .await .unwrap(); + alice_machine + .store + .save_inbound_group_sessions(&[session.unwrap()]) + .await + .unwrap(); } else { panic!("Invalid decrypted event type"); } @@ -1317,14 +1322,19 @@ mod test { .unwrap() .is_none()); - let (decrypted, sender_key, _) = + let (_, decrypted, sender_key, _) = alice_account.decrypt_to_device_event(&event).await.unwrap(); if let AnyToDeviceEvent::ForwardedRoomKey(mut e) = decrypted.deserialize().unwrap() { - alice_machine + let (_, session) = alice_machine .receive_forwarded_room_key(&sender_key, &mut e) .await .unwrap(); + alice_machine + .store + .save_inbound_group_sessions(&[session.unwrap()]) + .await + .unwrap(); } else { panic!("Invalid decrypted event type"); } diff --git a/matrix_sdk_crypto/src/lib.rs b/matrix_sdk_crypto/src/lib.rs index 6f1681266..0a1f0300e 100644 --- a/matrix_sdk_crypto/src/lib.rs +++ b/matrix_sdk_crypto/src/lib.rs @@ -29,7 +29,6 @@ mod error; mod file_encryption; -mod group_manager; mod identities; mod key_request; mod machine; diff --git a/matrix_sdk_crypto/src/machine.rs b/matrix_sdk_crypto/src/machine.rs index cbe0acbf0..b3d5589a0 100644 --- a/matrix_sdk_crypto/src/machine.rs +++ b/matrix_sdk_crypto/src/machine.rs @@ -17,6 +17,7 @@ use std::path::Path; use std::{collections::BTreeMap, mem, sync::Arc}; use dashmap::DashMap; +use matrix_sdk_common::locks::Mutex; use tracing::{debug, error, info, instrument, trace, warn}; use matrix_sdk_common::{ @@ -25,6 +26,7 @@ use matrix_sdk_common::{ claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, get_keys::Response as KeysQueryResponse, upload_keys, + upload_signatures::Request as UploadSignaturesRequest, }, sync::sync_events::Response as SyncResponse, }, @@ -45,17 +47,19 @@ use matrix_sdk_common::{ #[cfg(feature = "sqlite_cryptostore")] use crate::store::sqlite::SqliteStore; use crate::{ - error::{EventError, MegolmError, MegolmResult, OlmResult}, - group_manager::GroupSessionManager, - identities::{Device, IdentityManager, ReadOnlyDevice, UserDevices, UserIdentities}, + error::{EventError, MegolmError, MegolmResult, OlmError, OlmResult}, + identities::{Device, IdentityManager, UserDevices}, key_request::KeyRequestMachine, olm::{ Account, EncryptionSettings, ExportedRoomKey, GroupSessionKey, IdentityKeys, - InboundGroupSession, ReadOnlyAccount, + InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, + }, + requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, + session_manager::{GroupSessionManager, SessionManager}, + store::{ + Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, + Store, }, - requests::{IncomingResponse, OutgoingRequest}, - session_manager::SessionManager, - store::{CryptoStore, MemoryStore, Result as StoreResult, Store}, verification::{Sas, VerificationMachine}, ToDeviceRequest, }; @@ -70,6 +74,11 @@ pub struct OlmMachine { device_id: Arc>, /// Our underlying Olm Account holding our identity keys. account: Account, + /// The private part of our cross signing identity. + /// Used to sign devices and other users, might be missing if some other + /// device bootstraped cross signing or cross signing isn't bootstrapped at + /// all. + user_identity: Arc>, /// Store for the encryption keys. /// Persists all the encryption keys so a client can resume the session /// without the need to create new keys. @@ -87,6 +96,7 @@ pub struct OlmMachine { /// State machine handling public user identities and devices, keeping track /// of when a key query needs to be done and handling one. identity_manager: IdentityManager, + cross_signing_request: Arc>>, } #[cfg(not(tarpaulin_include))] @@ -115,7 +125,13 @@ impl OlmMachine { let device_id: DeviceIdBox = device_id.into(); let account = ReadOnlyAccount::new(&user_id, &device_id); - OlmMachine::new_helper(user_id, device_id, store, account) + OlmMachine::new_helper( + user_id, + device_id, + store, + account, + PrivateCrossSigningIdentity::empty(user_id.to_owned()), + ) } fn new_helper( @@ -123,19 +139,25 @@ impl OlmMachine { device_id: DeviceIdBox, store: Box, account: ReadOnlyAccount, + user_identity: PrivateCrossSigningIdentity, ) -> Self { let user_id = Arc::new(user_id.clone()); + let user_identity = Arc::new(Mutex::new(user_identity)); let store = Arc::new(store); - let verification_machine = VerificationMachine::new(account.clone(), store.clone()); + let verification_machine = + VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); let store = Store::new(user_id.clone(), store, verification_machine.clone()); let device_id: Arc = Arc::new(device_id); let outbound_group_sessions = Arc::new(DashMap::new()); + let users_for_key_claim = Arc::new(DashMap::new()); + let key_request_machine = KeyRequestMachine::new( user_id.clone(), device_id.clone(), store.clone(), outbound_group_sessions, + users_for_key_claim.clone(), ); let account = Account { @@ -143,8 +165,12 @@ impl OlmMachine { store: store.clone(), }; - let session_manager = - SessionManager::new(account.clone(), key_request_machine.clone(), store.clone()); + let session_manager = SessionManager::new( + account.clone(), + users_for_key_claim, + key_request_machine.clone(), + store.clone(), + ); let group_session_manager = GroupSessionManager::new(account.clone(), store.clone()); let identity_manager = IdentityManager::new( user_id.clone(), @@ -157,12 +183,14 @@ impl OlmMachine { user_id, device_id, account, + user_identity, store, session_manager, group_session_manager, verification_machine, key_request_machine, identity_manager, + cross_signing_request: Arc::new(Mutex::new(None)), } } @@ -197,11 +225,26 @@ impl OlmMachine { } None => { debug!("Creating a new account"); - ReadOnlyAccount::new(&user_id, &device_id) + let account = ReadOnlyAccount::new(&user_id, &device_id); + store.save_account(account.clone()).await?; + account } }; - Ok(OlmMachine::new_helper(&user_id, device_id, store, account)) + let identity = match store.load_identity().await? { + Some(i) => { + debug!("Restored the cross signing identity"); + i + } + None => { + debug!("Creating an empty cross signing identity stub"); + PrivateCrossSigningIdentity::empty(user_id.clone()) + } + }; + + Ok(OlmMachine::new_helper( + &user_id, device_id, store, account, identity, + )) } /// Create a new machine with the default crypto store. @@ -305,11 +348,58 @@ impl OlmMachine { IncomingResponse::ToDevice(_) => { self.mark_to_device_request_as_sent(&request_id).await?; } + IncomingResponse::SigningKeysUpload(_) => { + self.receive_cross_signing_upload_response().await?; + } }; Ok(()) } + /// Mark the cross signing identity as shared. + async fn receive_cross_signing_upload_response(&self) -> StoreResult<()> { + self.user_identity.lock().await.mark_as_shared(); + self.store + .save_identity((&*self.user_identity.lock().await).clone()) + .await + } + + /// Create a new cross signing identity and get the upload request to push + /// the new public keys to the server. + /// + /// **Warning**: This will delete any existing cross signing keys that might + /// exist on the server and thus will reset the trust between all the + /// devices. + /// + /// Uploading these keys will require user interactive auth. + pub async fn bootstrap_cross_signing( + &self, + reset: bool, + ) -> StoreResult<(UploadSigningKeysRequest, UploadSignaturesRequest)> { + let mut identity = self.user_identity.lock().await; + + if identity.is_empty().await || reset { + info!("Creating new cross signing identity"); + let (id, signature_request) = self.account.bootstrap_cross_signing().await; + let request = id.as_upload_request().await; + + *identity = id; + + self.store.save_identity(identity.clone()).await?; + Ok((request, signature_request)) + } else { + info!("Trying to upload the existing cross signing identity"); + let request = identity.as_upload_request().await; + let device_keys = self.account.unsigned_device_keys(); + // TODO remove this expect. + let signature_request = identity + .sign_device(device_keys) + .await + .expect("Can't sign device keys"); + Ok((request, signature_request)) + } + } + /// Should device or one-time keys be uploaded to the server. /// /// This needs to be checked periodically, ideally after every sync request. @@ -417,7 +507,7 @@ impl OlmMachine { async fn receive_keys_query_response( &self, response: &KeysQueryResponse, - ) -> OlmResult<(Vec, Vec)> { + ) -> OlmResult<(DeviceChanges, IdentityChanges)> { self.identity_manager .receive_keys_query_response(response) .await @@ -448,12 +538,12 @@ impl OlmMachine { async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult> { - let (decrypted_event, sender_key, signing_key) = + ) -> OlmResult<(Session, Raw, Option)> { + let (session, decrypted_event, sender_key, signing_key) = self.account.decrypt_to_device_event(event).await?; // Handle the decrypted event, e.g. fetch out Megolm sessions out of // the event. - if let Some(event) = self + if let (Some(event), group_session) = self .handle_decrypted_to_device_event(&sender_key, &signing_key, &decrypted_event) .await? { @@ -462,9 +552,9 @@ impl OlmMachine { // don't want them to be able to do silly things with it. Handling // events modifies them and returns a modified one, so replace it // here if we get one. - Ok(event) + Ok((session, event, group_session)) } else { - Ok(decrypted_event) + Ok((session, decrypted_event, None)) } } @@ -474,7 +564,7 @@ impl OlmMachine { sender_key: &str, signing_key: &str, event: &mut ToDeviceEvent, - ) -> OlmResult>> { + ) -> OlmResult<(Option>, Option)> { match event.content.algorithm { EventEncryptionAlgorithm::MegolmV1AesSha2 => { let session_key = GroupSessionKey(mem::take(&mut event.content.session_key)); @@ -485,17 +575,15 @@ impl OlmMachine { &event.content.room_id, session_key, )?; - let _ = self.store.save_inbound_group_sessions(&[session]).await?; - let event = Raw::from(AnyToDeviceEvent::RoomKey(event.clone())); - Ok(Some(event)) + Ok((Some(event), Some(session))) } _ => { warn!( "Received room key with unsupported key algorithm {}", event.content.algorithm ); - Ok(None) + Ok((None, None)) } } } @@ -505,9 +593,14 @@ impl OlmMachine { &self, room_id: &RoomId, ) -> OlmResult<()> { - self.group_session_manager + let (_, session) = self + .group_session_manager .create_outbound_group_session(room_id, EncryptionSettings::default()) - .await + .await?; + + self.store.save_inbound_group_sessions(&[session]).await?; + + Ok(()) } /// Encrypt a room message for the given room. @@ -597,12 +690,12 @@ impl OlmMachine { sender_key: &str, signing_key: &str, event: &Raw, - ) -> OlmResult>> { + ) -> OlmResult<(Option>, Option)> { let event = if let Ok(e) = event.deserialize() { e } else { warn!("Decrypted to-device event failed to be parsed correctly"); - return Ok(None); + return Ok((None, None)); }; match event { @@ -615,7 +708,7 @@ impl OlmMachine { .await?), _ => { warn!("Received a unexpected encrypted to-device event"); - Ok(None) + Ok((None, None)) } } } @@ -638,6 +731,8 @@ impl OlmMachine { .mark_outgoing_request_as_sent(request_id) .await?; self.group_session_manager.mark_request_as_sent(request_id); + self.session_manager + .mark_outgoing_request_as_sent(request_id); Ok(()) } @@ -647,11 +742,8 @@ impl OlmMachine { self.verification_machine.get_sas(flow_id) } - async fn update_one_time_key_count( - &self, - key_count: &BTreeMap, - ) -> StoreResult<()> { - self.account.update_uploaded_key_count(key_count).await + async fn update_one_time_key_count(&self, key_count: &BTreeMap) { + self.account.update_uploaded_key_count(key_count).await; } /// Handle a sync response and update the internal state of the Olm machine. @@ -667,15 +759,19 @@ impl OlmMachine { /// /// [`decrypt_room_event`]: #method.decrypt_room_event #[instrument(skip(response))] - pub async fn receive_sync_response(&self, response: &mut SyncResponse) { + pub async fn receive_sync_response(&self, response: &mut SyncResponse) -> OlmResult<()> { + // Remove verification objects that have expired or are done. self.verification_machine.garbage_collect(); - if let Err(e) = self - .update_one_time_key_count(&response.device_one_time_keys_count) - .await - { - error!("Error updating the one-time key count {:?}", e); - } + // Always save the account, a new session might get created which also + // touches the account. + let mut changes = Changes { + account: Some(self.account.inner.clone()), + ..Default::default() + }; + + self.update_one_time_key_count(&response.device_one_time_keys_count) + .await; for user_id in &response.device_lists.changed { if let Err(e) = self.identity_manager.mark_user_as_changed(&user_id).await { @@ -696,18 +792,36 @@ impl OlmMachine { match &mut event { AnyToDeviceEvent::RoomEncrypted(e) => { - let decrypted_event = match self.decrypt_to_device_event(e).await { - Ok(e) => e, - Err(err) => { - warn!( - "Failed to decrypt to-device event from {} {}", - e.sender, err - ); - // TODO if the session is wedged mark it for - // unwedging. - continue; - } - }; + let (session, decrypted_event, group_session) = + match self.decrypt_to_device_event(e).await { + Ok(e) => e, + Err(err) => { + warn!( + "Failed to decrypt to-device event from {} {}", + e.sender, err + ); + + if let OlmError::SessionWedged(sender, curve_key) = err { + if let Err(e) = self + .session_manager + .mark_device_as_wedged(&sender, &curve_key) + .await + { + error!( + "Couldn't mark device from {} to be unwedged {:?}", + sender, e + ); + } + } + continue; + } + }; + + changes.sessions.push(session); + + if let Some(group_session) = group_session { + changes.inbound_group_sessions.push(group_session); + } *event_result = decrypted_event; } @@ -726,13 +840,14 @@ impl OlmMachine { } } - if let Err(e) = self + let changed_sessions = self .key_request_machine .collect_incoming_key_requests() - .await - { - error!("Error collecting our key share requests {:?}", e); - } + .await?; + + changes.sessions.extend(changed_sessions); + + Ok(self.store.save_changes(changes).await?) } /// Decrypt an event from a room timeline. @@ -887,6 +1002,7 @@ impl OlmMachine { // Only import the session if we didn't have this session or if it's // a better version of the same session, that is the first known // index is lower. + // TODO load all sessions so we don't do a thousand small loads. if let Some(existing_session) = self .store .get_inbound_group_session( @@ -909,7 +1025,17 @@ impl OlmMachine { let num_sessions = sessions.len(); - self.store.save_inbound_group_sessions(&sessions).await?; + let changes = Changes { + inbound_group_sessions: sessions, + ..Default::default() + }; + + self.store.save_changes(changes).await?; + + info!( + "Successfully imported {} inbound group sessions", + num_sessions + ); Ok(num_sessions) } @@ -1130,15 +1256,19 @@ pub(crate) mod test { .unwrap() .unwrap(); + let (session, content) = bob_device + .encrypt(EventType::Dummy, json!({})) + .await + .unwrap(); + alice.store.save_sessions(&[session]).await.unwrap(); + let event = ToDeviceEvent { sender: alice.user_id().clone(), - content: bob_device - .encrypt(EventType::Dummy, json!({})) - .await - .unwrap(), + content, }; - bob.decrypt_to_device_event(&event).await.unwrap(); + let (session, _, _) = bob.decrypt_to_device_event(&event).await.unwrap(); + bob.store.save_sessions(&[session]).await.unwrap(); (alice, bob) } @@ -1424,13 +1554,15 @@ pub(crate) mod test { content: bob_device .encrypt(EventType::Dummy, json!({})) .await - .unwrap(), + .unwrap() + .1, }; let event = bob .decrypt_to_device_event(&event) .await .unwrap() + .1 .deserialize() .unwrap(); @@ -1466,12 +1598,14 @@ pub(crate) mod test { .get_outbound_group_session(&room_id) .unwrap(); - let event = bob - .decrypt_to_device_event(&event) + let (session, event, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + + bob.store.save_sessions(&[session]).await.unwrap(); + bob.store + .save_inbound_group_sessions(&[group_session.unwrap()]) .await - .unwrap() - .deserialize() .unwrap(); + let event = event.deserialize().unwrap(); if let AnyToDeviceEvent::RoomKey(event) = event { assert_eq!(&event.sender, alice.user_id()); @@ -1511,7 +1645,11 @@ pub(crate) mod test { content: to_device_requests_to_content(to_device_requests), }; - bob.decrypt_to_device_event(&event).await.unwrap(); + let (_, _, group_session) = bob.decrypt_to_device_event(&event).await.unwrap(); + bob.store + .save_inbound_group_sessions(&[group_session.unwrap()]) + .await + .unwrap(); let plaintext = "It is a secret to everybody"; @@ -1557,7 +1695,7 @@ pub(crate) mod test { } } - #[tokio::test] + #[tokio::test(threaded_scheduler)] #[cfg(feature = "sqlite_cryptostore")] async fn test_machine_with_default_store() { let tmpdir = tempdir().unwrap(); diff --git a/matrix_sdk_crypto/src/olm/account.rs b/matrix_sdk_crypto/src/olm/account.rs index 0cf5e5691..49ce924f7 100644 --- a/matrix_sdk_crypto/src/olm/account.rs +++ b/matrix_sdk_crypto/src/olm/account.rs @@ -30,7 +30,9 @@ use tracing::{debug, trace, warn}; #[cfg(test)] use matrix_sdk_common::events::EventType; use matrix_sdk_common::{ - api::r0::keys::{upload_keys, OneTimeKey, SignedKey}, + api::r0::keys::{ + upload_keys, upload_signatures::Request as SignatureUploadRequest, OneTimeKey, SignedKey, + }, encryption::DeviceKeys, events::{room::encrypted::EncryptedEventContent, AnyToDeviceEvent}, identifiers::{ @@ -50,13 +52,16 @@ use olm_rs::{ }; use crate::{ - error::{EventError, OlmResult, SessionCreationError}, + error::{EventError, OlmResult, SessionCreationError, SignatureError}, identities::ReadOnlyDevice, - store::{Result as StoreResult, Store}, + store::Store, OlmError, }; -use super::{EncryptionSettings, InboundGroupSession, OutboundGroupSession, Session}; +use super::{ + EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity, + Session, +}; #[derive(Debug, Clone)] pub struct Account { @@ -76,7 +81,7 @@ impl Account { pub async fn decrypt_to_device_event( &self, event: &ToDeviceEvent, - ) -> OlmResult<(Raw, String, String)> { + ) -> OlmResult<(Session, Raw, String, String)> { debug!("Decrypting to-device event"); let content = if let EncryptedEventContent::OlmV1Curve25519AesSha2(c) = &event.content { @@ -103,27 +108,28 @@ impl Account { .map_err(|_| EventError::UnsupportedOlmType)?; // Decrypt the OlmMessage and get a Ruma event out of it. - let (decrypted_event, signing_key) = self + let (session, decrypted_event, signing_key) = self .decrypt_olm_message(&event.sender, &content.sender_key, message) .await?; debug!("Decrypted a to-device event {:?}", decrypted_event); - Ok((decrypted_event, content.sender_key.clone(), signing_key)) + Ok(( + session, + decrypted_event, + content.sender_key.clone(), + signing_key, + )) } else { warn!("Olm event doesn't contain a ciphertext for our key"); Err(EventError::MissingCiphertext.into()) } } - pub async fn update_uploaded_key_count( - &self, - key_count: &BTreeMap, - ) -> StoreResult<()> { + pub async fn update_uploaded_key_count(&self, key_count: &BTreeMap) { let one_time_key_count = key_count.get(&DeviceKeyAlgorithm::SignedCurve25519); let count: u64 = one_time_key_count.map_or(0, |c| (*c).into()); self.inner.update_uploaded_key_count(count); - self.store.save_account(self.inner.clone()).await } pub async fn receive_keys_upload_response( @@ -161,7 +167,7 @@ impl Account { sender: &UserId, sender_key: &str, message: &OlmMessage, - ) -> OlmResult> { + ) -> OlmResult> { let s = self.store.get_sessions(sender_key).await?; // We don't have any existing sessions, return early. @@ -171,8 +177,7 @@ impl Account { return Ok(None); }; - let mut session_to_save = None; - let mut plaintext = None; + let mut decrypted: Option<(Session, String)> = None; for session in &mut *sessions.lock().await { let mut matches = false; @@ -191,9 +196,7 @@ impl Account { match ret { Ok(p) => { - plaintext = Some(p); - session_to_save = Some(session.clone()); - + decrypted = Some((session.clone(), p)); break; } Err(e) => { @@ -205,20 +208,16 @@ impl Account { for sender {} and sender_key {} {:?}", sender, sender_key, e ); - return Err(OlmError::SessionWedged); + return Err(OlmError::SessionWedged( + sender.to_owned(), + sender_key.to_owned(), + )); } } } } - if let Some(session) = session_to_save { - // Decryption was successful, save the new ratchet state of the - // session that was used to decrypt the message. - trace!("Saved the new session state for {}", sender); - self.store.save_sessions(&[session]).await?; - } - - Ok(plaintext) + Ok(decrypted) } /// Decrypt an Olm message, creating a new Olm session if possible. @@ -227,15 +226,15 @@ impl Account { sender: &UserId, sender_key: &str, message: OlmMessage, - ) -> OlmResult<(Raw, String)> { + ) -> OlmResult<(Session, Raw, String)> { // First try to decrypt using an existing session. - let plaintext = if let Some(p) = self + let (session, plaintext) = if let Some(d) = self .try_decrypt_olm_message(sender, sender_key, &message) .await? { - // Decryption succeeded, de-structure the plaintext out of the - // Option. - p + // Decryption succeeded, de-structure the session/plaintext out of + // the Option. + d } else { // Decryption failed with every known session, let's try to create a // new session. @@ -248,7 +247,10 @@ impl Account { available sessions {} {}", sender, sender_key ); - return Err(OlmError::SessionWedged); + return Err(OlmError::SessionWedged( + sender.to_owned(), + sender_key.to_owned(), + )); } OlmMessage::PreKey(m) => { @@ -265,13 +267,13 @@ impl Account { from a prekey message: {}", sender, sender_key, e ); - return Err(OlmError::SessionWedged); + return Err(OlmError::SessionWedged( + sender.to_owned(), + sender_key.to_owned(), + )); } }; - // Save the account since we remove the one-time key that - // was used to create this session. - self.store.save_account(self.inner.clone()).await?; session } }; @@ -279,15 +281,23 @@ impl Account { // Decrypt our message, this shouldn't fail since we're using a // newly created Session. let plaintext = session.decrypt(message).await?; - - // Save the new ratcheted state of the session. - self.store.save_sessions(&[session]).await?; - plaintext + (session, plaintext) }; trace!("Successfully decrypted a Olm message: {}", plaintext); - self.parse_decrypted_to_device_event(sender, &plaintext) + let (event, signing_key) = match self.parse_decrypted_to_device_event(sender, &plaintext) { + Ok(r) => r, + Err(e) => { + // We might created a new session but decryption might still + // have failed, store it for the error case here, this is fine + // since we don't expect this to happen often or at all. + self.store.save_sessions(&[session]).await?; + return Err(e); + } + }; + + Ok((session, event, signing_key)) } /// Parse a decrypted Olm message, check that the plaintext and encrypted @@ -613,9 +623,7 @@ impl ReadOnlyAccount { }) } - /// Sign the device keys of the account and return them so they can be - /// uploaded. - pub(crate) async fn device_keys(&self) -> DeviceKeys { + pub(crate) fn unsigned_device_keys(&self) -> DeviceKeys { let identity_keys = self.identity_keys(); let mut keys = BTreeMap::new(); @@ -629,34 +637,41 @@ impl ReadOnlyAccount { identity_keys.ed25519().to_owned(), ); - let device_keys = json!({ - "user_id": (*self.user_id).clone(), - "device_id": (*self.device_id).clone(), - "algorithms": Self::ALGORITHMS, - "keys": keys, - }); - - let mut signatures = BTreeMap::new(); - - let mut signature = BTreeMap::new(); - signature.insert( - DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id), - self.sign_json(&device_keys).await, - ); - signatures.insert((*self.user_id).clone(), signature); - DeviceKeys::new( (*self.user_id).clone(), (*self.device_id).clone(), - vec![ - EventEncryptionAlgorithm::OlmV1Curve25519AesSha2, - EventEncryptionAlgorithm::MegolmV1AesSha2, - ], + Self::ALGORITHMS.iter().map(|a| (&**a).clone()).collect(), keys, - signatures, + BTreeMap::new(), ) } + /// Sign the device keys of the account and return them so they can be + /// uploaded. + pub(crate) async fn device_keys(&self) -> DeviceKeys { + let mut device_keys = self.unsigned_device_keys(); + let jsond_device_keys = serde_json::to_value(&device_keys).unwrap(); + + device_keys + .signatures + .entry(self.user_id().clone()) + .or_insert_with(BTreeMap::new) + .insert( + DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id), + self.sign_json(jsond_device_keys) + .await + .expect("Can't sign own device keys"), + ); + + device_keys + } + + pub(crate) async fn bootstrap_cross_signing( + &self, + ) -> (PrivateCrossSigningIdentity, SignatureUploadRequest) { + PrivateCrossSigningIdentity::new_with_account(self).await + } + /// Convert a JSON value to the canonical representation and sign the JSON /// string. /// @@ -668,20 +683,18 @@ impl ReadOnlyAccount { /// # Panic /// /// Panics if the json value can't be serialized. - pub async fn sign_json(&self, json: &Value) -> String { - let canonical_json = cjson::to_string(json) - .unwrap_or_else(|_| panic!(format!("Can't serialize {} to canonical JSON", json))); - self.sign(&canonical_json).await + pub async fn sign_json(&self, mut json: Value) -> Result { + let json_object = json.as_object_mut().ok_or(SignatureError::NotAnObject)?; + let _ = json_object.remove("unsigned"); + let _ = json_object.remove("signatures"); + + let canonical_json = cjson::to_string(&json)?; + Ok(self.sign(&canonical_json).await) } - /// Generate, sign and prepare one-time keys to be uploaded. - /// - /// If no one-time keys need to be uploaded returns an empty error. - pub(crate) async fn signed_one_time_keys( + pub(crate) async fn signed_one_time_keys_helper( &self, ) -> Result, ()> { - let _ = self.generate_one_time_keys().await?; - let one_time_keys = self.one_time_keys().await; let mut one_time_key_map = BTreeMap::new(); @@ -690,7 +703,10 @@ impl ReadOnlyAccount { "key": key, }); - let signature = self.sign_json(&key_json).await; + let signature = self + .sign_json(key_json) + .await + .expect("Can't sign own one-time keys"); let mut signature_map = BTreeMap::new(); @@ -719,6 +735,16 @@ impl ReadOnlyAccount { Ok(one_time_key_map) } + /// Generate, sign and prepare one-time keys to be uploaded. + /// + /// If no one-time keys need to be uploaded returns an empty error. + pub(crate) async fn signed_one_time_keys( + &self, + ) -> Result, ()> { + let _ = self.generate_one_time_keys().await?; + self.signed_one_time_keys_helper().await + } + /// Create a new session with another account given a one-time key. /// /// Returns the newly created session or a `OlmSessionError` if creating a diff --git a/matrix_sdk_crypto/src/olm/mod.rs b/matrix_sdk_crypto/src/olm/mod.rs index eef28ff19..116f9b828 100644 --- a/matrix_sdk_crypto/src/olm/mod.rs +++ b/matrix_sdk_crypto/src/olm/mod.rs @@ -20,6 +20,7 @@ mod account; mod group_sessions; mod session; +mod signing; mod utility; pub(crate) use account::Account; @@ -31,6 +32,7 @@ pub use group_sessions::{ pub(crate) use group_sessions::{GroupSessionKey, OutboundGroupSession}; pub use olm_rs::{account::IdentityKeys, PicklingMode}; pub use session::{PickledSession, Session, SessionPickle}; +pub use signing::{PickledCrossSigningIdentity, PrivateCrossSigningIdentity}; pub(crate) use utility::Utility; #[cfg(test)] diff --git a/matrix_sdk_crypto/src/olm/signing.rs b/matrix_sdk_crypto/src/olm/signing.rs new file mode 100644 index 000000000..d43325014 --- /dev/null +++ b/matrix_sdk_crypto/src/olm/signing.rs @@ -0,0 +1,772 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(dead_code, missing_docs)] + +use aes_gcm::{ + aead::{generic_array::GenericArray, Aead, NewAead}, + Aes256Gcm, +}; +use base64::{decode_config, encode_config, DecodeError, URL_SAFE_NO_PAD}; +use getrandom::getrandom; +use matrix_sdk_common::{ + encryption::DeviceKeys, + identifiers::{DeviceKeyAlgorithm, DeviceKeyId}, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Error as JsonError, Value}; +use std::{ + collections::BTreeMap, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; +use thiserror::Error; +use zeroize::Zeroizing; + +use olm_rs::{errors::OlmUtilityError, pk::OlmPkSigning, utility::OlmUtility}; + +use matrix_sdk_common::{ + api::r0::keys::{ + upload_signatures::Request as SignatureUploadRequest, CrossSigningKey, KeyUsage, + }, + identifiers::UserId, + locks::Mutex, +}; + +use crate::{ + error::SignatureError, + identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, + requests::UploadSigningKeysRequest, + UserIdentity, +}; + +use crate::ReadOnlyAccount; + +const NONCE_SIZE: usize = 12; + +fn encode>(input: T) -> String { + encode_config(input, URL_SAFE_NO_PAD) +} + +fn decode>(input: T) -> Result, DecodeError> { + decode_config(input, URL_SAFE_NO_PAD) +} + +/// Error type reporting failures in the Signign operations. +#[derive(Debug, Error)] +pub enum SigningError { + /// Error decoding the base64 encoded pickle data. + #[error(transparent)] + Decode(#[from] DecodeError), + + /// Error decrypting the pickled signing seed + #[error("Error decrypting the pickled signign seed")] + Decryption(String), + + /// Error deserializing the pickle data. + #[error(transparent)] + Json(#[from] JsonError), +} + +#[derive(Clone)] +pub struct Signing { + inner: Arc>, + seed: Arc>>, + public_key: PublicSigningKey, +} + +impl std::fmt::Debug for Signing { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Signing") + .field("public_key", &self.public_key.as_str()) + .finish() + } +} + +impl PartialEq for Signing { + fn eq(&self, other: &Signing) -> bool { + self.seed == other.seed + } +} + +#[derive(Clone, PartialEq, Debug)] +struct MasterSigning { + inner: Signing, + public_key: MasterPubkey, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct PickledMasterSigning { + pickle: PickledSigning, + public_key: CrossSigningKey, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct PickledUserSigning { + pickle: PickledSigning, + public_key: CrossSigningKey, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +struct PickledSelfSigning { + pickle: PickledSigning, + public_key: CrossSigningKey, +} + +impl MasterSigning { + async fn pickle(&self, pickle_key: &[u8]) -> PickledMasterSigning { + let pickle = self.inner.pickle(pickle_key).await; + let public_key = self.public_key.clone().into(); + PickledMasterSigning { pickle, public_key } + } + + fn from_pickle(pickle: PickledMasterSigning, pickle_key: &[u8]) -> Result { + let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; + + Ok(Self { + inner, + public_key: pickle.public_key.into(), + }) + } + + async fn sign_subkey<'a>(&self, subkey: &mut CrossSigningKey) { + // TODO create a borrowed version of a cross singing key. + let subkey_wihtout_signatures = CrossSigningKey { + user_id: subkey.user_id.clone(), + keys: subkey.keys.clone(), + usage: subkey.usage.clone(), + signatures: BTreeMap::new(), + }; + + let message = cjson::to_string(&subkey_wihtout_signatures) + .expect("Can't serialize cross signing subkey"); + let signature = self.inner.sign(&message).await; + + subkey + .signatures + .entry(self.public_key.user_id().to_owned()) + .or_insert_with(BTreeMap::new) + .insert( + format!("ed25519:{}", self.inner.public_key().as_str()), + signature.0, + ); + } +} + +impl UserSigning { + async fn pickle(&self, pickle_key: &[u8]) -> PickledUserSigning { + let pickle = self.inner.pickle(pickle_key).await; + let public_key = self.public_key.clone().into(); + PickledUserSigning { pickle, public_key } + } + + async fn sign_user(&self, _: &UserIdentity) -> BTreeMap> { + todo!(); + } + + fn from_pickle(pickle: PickledUserSigning, pickle_key: &[u8]) -> Result { + let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; + + Ok(Self { + inner, + public_key: pickle.public_key.into(), + }) + } +} + +impl SelfSigning { + async fn pickle(&self, pickle_key: &[u8]) -> PickledSelfSigning { + let pickle = self.inner.pickle(pickle_key).await; + let public_key = self.public_key.clone().into(); + PickledSelfSigning { pickle, public_key } + } + + async fn sign_device_raw(&self, value: Value) -> Result { + self.inner.sign_json(value).await + } + + async fn sign_device(&self, device_keys: &mut DeviceKeys) -> Result<(), SignatureError> { + let json_device = serde_json::to_value(&device_keys)?; + let signature = self.sign_device_raw(json_device).await?; + + device_keys + .signatures + .entry(self.public_key.user_id().to_owned()) + .or_insert_with(BTreeMap::new) + .insert( + DeviceKeyId::from_parts( + DeviceKeyAlgorithm::Ed25519, + self.inner.public_key.as_str().into(), + ), + signature.0, + ); + + Ok(()) + } + + fn from_pickle(pickle: PickledSelfSigning, pickle_key: &[u8]) -> Result { + let inner = Signing::from_pickle(pickle.pickle, pickle_key)?; + + Ok(Self { + inner, + public_key: pickle.public_key.into(), + }) + } +} + +#[derive(Clone, PartialEq, Debug)] +struct SelfSigning { + inner: Signing, + public_key: SelfSigningPubkey, +} + +#[derive(Clone, PartialEq, Debug)] +struct UserSigning { + inner: Signing, + public_key: UserSigningPubkey, +} + +#[derive(Clone, Debug)] +pub struct PrivateCrossSigningIdentity { + user_id: Arc, + shared: Arc, + master_key: Arc>>, + user_signing_key: Arc>>, + self_signing_key: Arc>>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PickledCrossSigningIdentity { + pub user_id: UserId, + pub shared: bool, + pub pickle: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PickledSignings { + master_key: Option, + user_signing_key: Option, + self_signing_key: Option, +} + +#[derive(Debug, Clone)] +pub struct Signature(String); + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PickledSigning(String); + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct InnerPickle { + version: u8, + nonce: String, + ciphertext: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PublicSigningKey(Arc); + +impl Signature { + fn as_str(&self) -> &str { + &self.0 + } +} + +impl PickledSigning { + fn as_str(&self) -> &str { + &self.0 + } +} + +impl PublicSigningKey { + fn as_str(&self) -> &str { + &self.0 + } + + #[allow(clippy::inherent_to_string)] + fn to_string(&self) -> String { + self.0.to_string() + } +} + +impl Signing { + fn new() -> Self { + let seed = OlmPkSigning::generate_seed(); + Self::from_seed(seed) + } + + fn from_seed(seed: Vec) -> Self { + let inner = OlmPkSigning::new(seed.clone()).expect("Unable to create pk signing object"); + let public_key = PublicSigningKey(Arc::new(inner.public_key().to_owned())); + + Signing { + inner: Arc::new(Mutex::new(inner)), + seed: Arc::new(Zeroizing::from(seed)), + public_key, + } + } + + fn from_pickle(pickle: PickledSigning, pickle_key: &[u8]) -> Result { + let pickled: InnerPickle = serde_json::from_str(pickle.as_str())?; + + let key = GenericArray::from_slice(pickle_key); + let cipher = Aes256Gcm::new(key); + + let nonce = decode(pickled.nonce)?; + let nonce = GenericArray::from_slice(&nonce); + let ciphertext = &decode(pickled.ciphertext)?; + + let seed = cipher + .decrypt(&nonce, ciphertext.as_slice()) + .map_err(|e| SigningError::Decryption(e.to_string()))?; + + Ok(Self::from_seed(seed)) + } + + async fn pickle(&self, pickle_key: &[u8]) -> PickledSigning { + let key = GenericArray::from_slice(pickle_key); + let cipher = Aes256Gcm::new(key); + + let mut nonce = vec![0u8; NONCE_SIZE]; + getrandom(&mut nonce).expect("Can't generate nonce to pickle the signing object"); + let nonce = GenericArray::from_slice(nonce.as_slice()); + + let ciphertext = cipher + .encrypt(nonce, self.seed.as_slice()) + .expect("Can't encrypt signing pickle"); + + let ciphertext = encode(ciphertext); + + let pickle = InnerPickle { + version: 1, + nonce: encode(nonce.as_slice()), + ciphertext, + }; + + PickledSigning(serde_json::to_string(&pickle).expect("Can't encode pickled signing")) + } + + fn public_key(&self) -> &PublicSigningKey { + &self.public_key + } + + fn cross_signing_key(&self, user_id: UserId, usage: KeyUsage) -> CrossSigningKey { + let mut keys = BTreeMap::new(); + + keys.insert( + format!("ed25519:{}", self.public_key().as_str()), + self.public_key().to_string(), + ); + + CrossSigningKey { + user_id, + usage: vec![usage], + keys, + signatures: BTreeMap::new(), + } + } + + async fn verify(&self, message: &str, signature: &Signature) -> Result { + let utility = OlmUtility::new(); + utility.ed25519_verify(self.public_key.as_str(), message, signature.as_str()) + } + + async fn sign_json(&self, mut json: Value) -> Result { + let json_object = json.as_object_mut().ok_or(SignatureError::NotAnObject)?; + let _ = json_object.remove("signatures"); + let canonical_json = cjson::to_string(json_object)?; + Ok(self.sign(&canonical_json).await) + } + + async fn sign(&self, message: &str) -> Signature { + Signature(self.inner.lock().await.sign(message)) + } +} + +impl PrivateCrossSigningIdentity { + pub fn user_id(&self) -> &UserId { + &self.user_id + } + + pub async fn is_empty(&self) -> bool { + let has_master = self.master_key.lock().await.is_some(); + let has_user = self.user_signing_key.lock().await.is_some(); + let has_self = self.self_signing_key.lock().await.is_some(); + + !(has_master && has_user && has_self) + } + + pub(crate) fn empty(user_id: UserId) -> Self { + Self { + user_id: Arc::new(user_id), + shared: Arc::new(AtomicBool::new(false)), + master_key: Arc::new(Mutex::new(None)), + self_signing_key: Arc::new(Mutex::new(None)), + user_signing_key: Arc::new(Mutex::new(None)), + } + } + + pub(crate) async fn sign_device( + &self, + mut device_keys: DeviceKeys, + ) -> Result { + self.self_signing_key + .lock() + .await + .as_ref() + .ok_or(SignatureError::MissingSigningKey)? + .sign_device(&mut device_keys) + .await?; + + let mut signed_keys = BTreeMap::new(); + signed_keys + .entry((&*self.user_id).to_owned()) + .or_insert_with(BTreeMap::new) + .insert( + device_keys.device_id.to_string(), + serde_json::to_value(device_keys)?, + ); + + Ok(SignatureUploadRequest { signed_keys }) + } + + pub(crate) async fn new_with_account( + account: &ReadOnlyAccount, + ) -> (Self, SignatureUploadRequest) { + let master = Signing::new(); + + let mut public_key = + master.cross_signing_key(account.user_id().to_owned(), KeyUsage::Master); + let signature = account + .sign_json( + serde_json::to_value(&public_key) + .expect("Can't convert own public master key to json"), + ) + .await + .expect("Can't sign own public master key"); + public_key + .signatures + .entry(account.user_id().to_owned()) + .or_insert_with(BTreeMap::new) + .insert(format!("ed25519:{}", account.device_id()), signature); + + let master = MasterSigning { + inner: master, + public_key: public_key.into(), + }; + + let identity = Self::new_helper(account.user_id(), master).await; + let device_keys = account.unsigned_device_keys(); + let request = identity + .sign_device(device_keys) + .await + .expect("Can't sign own device with new cross signign keys"); + + (identity, request) + } + + async fn new_helper(user_id: &UserId, master: MasterSigning) -> Self { + let user = Signing::new(); + let mut public_key = user.cross_signing_key(user_id.to_owned(), KeyUsage::UserSigning); + master.sign_subkey(&mut public_key).await; + + let user = UserSigning { + inner: user, + public_key: public_key.into(), + }; + + let self_signing = Signing::new(); + let mut public_key = + self_signing.cross_signing_key(user_id.to_owned(), KeyUsage::SelfSigning); + master.sign_subkey(&mut public_key).await; + + let self_signing = SelfSigning { + inner: self_signing, + public_key: public_key.into(), + }; + + Self { + user_id: Arc::new(user_id.to_owned()), + shared: Arc::new(AtomicBool::new(false)), + master_key: Arc::new(Mutex::new(Some(master))), + self_signing_key: Arc::new(Mutex::new(Some(self_signing))), + user_signing_key: Arc::new(Mutex::new(Some(user))), + } + } + + pub(crate) async fn new(user_id: UserId) -> Self { + let master = Signing::new(); + + let public_key = master.cross_signing_key(user_id.clone(), KeyUsage::Master); + let master = MasterSigning { + inner: master, + public_key: public_key.into(), + }; + + let user = Signing::new(); + let mut public_key = user.cross_signing_key(user_id.clone(), KeyUsage::UserSigning); + master.sign_subkey(&mut public_key).await; + + let user = UserSigning { + inner: user, + public_key: public_key.into(), + }; + + let self_signing = Signing::new(); + let mut public_key = self_signing.cross_signing_key(user_id.clone(), KeyUsage::SelfSigning); + master.sign_subkey(&mut public_key).await; + + let self_signing = SelfSigning { + inner: self_signing, + public_key: public_key.into(), + }; + + Self { + user_id: Arc::new(user_id), + shared: Arc::new(AtomicBool::new(false)), + master_key: Arc::new(Mutex::new(Some(master))), + self_signing_key: Arc::new(Mutex::new(Some(self_signing))), + user_signing_key: Arc::new(Mutex::new(Some(user))), + } + } + + pub fn mark_as_shared(&self) { + self.shared.store(true, Ordering::SeqCst) + } + + pub fn shared(&self) -> bool { + self.shared.load(Ordering::SeqCst) + } + + pub async fn pickle( + &self, + pickle_key: &[u8], + ) -> Result { + let master_key = if let Some(m) = self.master_key.lock().await.as_ref() { + Some(m.pickle(pickle_key).await) + } else { + None + }; + + let self_signing_key = if let Some(m) = self.self_signing_key.lock().await.as_ref() { + Some(m.pickle(pickle_key).await) + } else { + None + }; + + let user_signing_key = if let Some(m) = self.user_signing_key.lock().await.as_ref() { + Some(m.pickle(pickle_key).await) + } else { + None + }; + + let pickle = PickledSignings { + master_key, + user_signing_key, + self_signing_key, + }; + + let pickle = serde_json::to_string(&pickle)?; + + Ok(PickledCrossSigningIdentity { + user_id: self.user_id.as_ref().to_owned(), + shared: self.shared(), + pickle, + }) + } + + /// Restore the private cross signing identity from a pickle. + /// + /// # Panic + /// + /// Panics if the pickle_key isn't 32 bytes long. + pub async fn from_pickle( + pickle: PickledCrossSigningIdentity, + pickle_key: &[u8], + ) -> Result { + let signings: PickledSignings = serde_json::from_str(&pickle.pickle)?; + + let master = if let Some(m) = signings.master_key { + Some(MasterSigning::from_pickle(m, pickle_key)?) + } else { + None + }; + + let self_signing = if let Some(s) = signings.self_signing_key { + Some(SelfSigning::from_pickle(s, pickle_key)?) + } else { + None + }; + + let user_signing = if let Some(u) = signings.user_signing_key { + Some(UserSigning::from_pickle(u, pickle_key)?) + } else { + None + }; + + Ok(Self { + user_id: Arc::new(pickle.user_id), + shared: Arc::new(AtomicBool::from(pickle.shared)), + master_key: Arc::new(Mutex::new(master)), + self_signing_key: Arc::new(Mutex::new(self_signing)), + user_signing_key: Arc::new(Mutex::new(user_signing)), + }) + } + + pub(crate) async fn as_upload_request(&self) -> UploadSigningKeysRequest { + let master_key = self + .master_key + .lock() + .await + .as_ref() + .cloned() + .map(|k| k.public_key.into()); + + let user_signing_key = self + .user_signing_key + .lock() + .await + .as_ref() + .cloned() + .map(|k| k.public_key.into()); + + let self_signing_key = self + .self_signing_key + .lock() + .await + .as_ref() + .cloned() + .map(|k| k.public_key.into()); + + UploadSigningKeysRequest { + master_key, + user_signing_key, + self_signing_key, + } + } +} + +#[cfg(test)] +mod test { + use crate::olm::ReadOnlyAccount; + + use super::{PrivateCrossSigningIdentity, Signing}; + + use matrix_sdk_common::identifiers::{user_id, UserId}; + use matrix_sdk_test::async_test; + + fn user_id() -> UserId { + user_id!("@example:localhost") + } + + fn pickle_key() -> &'static [u8] { + &[0u8; 32] + } + + #[test] + fn signing_creation() { + let signing = Signing::new(); + assert!(!signing.public_key().as_str().is_empty()); + } + + #[async_test] + async fn signature_verification() { + let signing = Signing::new(); + + let message = "Hello world"; + + let signature = signing.sign(message).await; + assert!(signing.verify(message, &signature).await.is_ok()); + } + + #[async_test] + async fn pickling_signing() { + let signing = Signing::new(); + let pickled = signing.pickle(pickle_key()).await; + + let unpickled = Signing::from_pickle(pickled, pickle_key()).unwrap(); + + assert_eq!(signing.public_key(), unpickled.public_key()); + } + + #[async_test] + async fn private_identity_creation() { + let identity = PrivateCrossSigningIdentity::new(user_id()).await; + + let master_key = identity.master_key.lock().await; + let master_key = master_key.as_ref().unwrap(); + + assert!(master_key + .public_key + .verify_subkey( + &identity + .self_signing_key + .lock() + .await + .as_ref() + .unwrap() + .public_key, + ) + .is_ok()); + + assert!(master_key + .public_key + .verify_subkey( + &identity + .user_signing_key + .lock() + .await + .as_ref() + .unwrap() + .public_key, + ) + .is_ok()); + } + + #[async_test] + async fn identity_pickling() { + let identity = PrivateCrossSigningIdentity::new(user_id()).await; + + let pickled = identity.pickle(pickle_key()).await.unwrap(); + + let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled, pickle_key()) + .await + .unwrap(); + + assert_eq!(identity.user_id, unpickled.user_id); + assert_eq!( + &*identity.master_key.lock().await, + &*unpickled.master_key.lock().await + ); + assert_eq!( + &*identity.user_signing_key.lock().await, + &*unpickled.user_signing_key.lock().await + ); + assert_eq!( + &*identity.self_signing_key.lock().await, + &*unpickled.self_signing_key.lock().await + ); + } + + #[async_test] + async fn private_identity_signed_by_accound() { + let account = ReadOnlyAccount::new(&user_id(), "DEVICEID".into()); + let (identity, _) = PrivateCrossSigningIdentity::new_with_account(&account).await; + let master = identity.master_key.lock().await; + let master = master.as_ref().unwrap(); + + assert!(!master.public_key.signatures().is_empty()); + } +} diff --git a/matrix_sdk_crypto/src/requests.rs b/matrix_sdk_crypto/src/requests.rs index 3a6ba4d97..5b70c7447 100644 --- a/matrix_sdk_crypto/src/requests.rs +++ b/matrix_sdk_crypto/src/requests.rs @@ -20,6 +20,8 @@ use matrix_sdk_common::{ claim_keys::Response as KeysClaimResponse, get_keys::Response as KeysQueryResponse, upload_keys::{Request as KeysUploadRequest, Response as KeysUploadResponse}, + upload_signing_keys::Response as SigningKeysUploadResponse, + CrossSigningKey, }, to_device::{send_event_to_device::Response as ToDeviceResponse, DeviceIdOrAllDevices}, }, @@ -56,6 +58,21 @@ impl ToDeviceRequest { } } +/// Request that will publish a cross signing identity. +/// +/// This uploads the public cross signing key triplet. +#[derive(Debug, Clone)] +pub struct UploadSigningKeysRequest { + /// The user's master key. + pub master_key: Option, + /// The user's self-signing key. Must be signed with the accompanied master, or by the + /// user's most recently uploaded master key if no master key is included in the request. + pub self_signing_key: Option, + /// The user's user-signing key. Must be signed with the accompanied master, or by the + /// user's most recently uploaded master key if no master key is included in the request. + pub user_signing_key: Option, +} + /// Customized version of `ruma_client_api::r0::keys::get_keys::Request`, without any references. #[derive(Clone, Debug)] pub struct KeysQueryRequest { @@ -141,6 +158,9 @@ pub enum IncomingResponse<'a> { /// The key claiming requests, giving us new one-time keys of other users so /// new Olm sessions can be created. KeysClaim(&'a KeysClaimResponse), + /// The cross signing keys upload response, marking our private cross + /// signing identity as shared. + SigningKeysUpload(&'a SigningKeysUploadResponse), } impl<'a> From<&'a KeysUploadResponse> for IncomingResponse<'a> { diff --git a/matrix_sdk_crypto/src/session_manager.rs b/matrix_sdk_crypto/src/session_manager.rs deleted file mode 100644 index e0dc57bf2..000000000 --- a/matrix_sdk_crypto/src/session_manager.rs +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{collections::BTreeMap, time::Duration}; - -use matrix_sdk_common::{ - api::r0::keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, - assign, - identifiers::{DeviceKeyAlgorithm, UserId}, - uuid::Uuid, -}; -use tracing::{error, info, warn}; - -use crate::{error::OlmResult, key_request::KeyRequestMachine, olm::Account, store::Store}; - -#[derive(Debug, Clone)] -pub(crate) struct SessionManager { - account: Account, - store: Store, - key_request_machine: KeyRequestMachine, -} - -impl SessionManager { - const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10); - - pub fn new(account: Account, key_request_machine: KeyRequestMachine, store: Store) -> Self { - Self { - account, - store, - key_request_machine, - } - } - - /// Get the a key claiming request for the user/device pairs that we are - /// missing Olm sessions for. - /// - /// Returns None if no key claiming request needs to be sent out. - /// - /// Sessions need to be established between devices so group sessions for a - /// room can be shared with them. - /// - /// This should be called every time a group session needs to be shared as - /// well as between sync calls. After a sync some devices may request room - /// keys without us having a valid Olm session with them, making it - /// impossible to server the room key request, thus it's necessary to check - /// for missing sessions between sync as well. - /// - /// **Note**: Care should be taken that only one such request at a time is - /// in flight, e.g. using a lock. - /// - /// The response of a successful key claiming requests needs to be passed to - /// the `OlmMachine` with the [`receive_keys_claim_response`]. - /// - /// # Arguments - /// - /// `users` - The list of users that we should check if we lack a session - /// with one of their devices. This can be an empty iterator when calling - /// this method between sync requests. - /// - /// [`receive_keys_claim_response`]: #method.receive_keys_claim_response - pub async fn get_missing_sessions( - &self, - users: &mut impl Iterator, - ) -> OlmResult> { - let mut missing = BTreeMap::new(); - - // Add the list of devices that the user wishes to establish sessions - // right now. - for user_id in users { - let user_devices = self.store.get_user_devices(user_id).await?; - - for device in user_devices.devices() { - let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) { - k - } else { - continue; - }; - - let sessions = self.store.get_sessions(sender_key).await?; - - let is_missing = if let Some(sessions) = sessions { - sessions.lock().await.is_empty() - } else { - true - }; - - if is_missing { - missing - .entry(user_id.to_owned()) - .or_insert_with(BTreeMap::new) - .insert( - device.device_id().into(), - DeviceKeyAlgorithm::SignedCurve25519, - ); - } - } - } - - // Add the list of sessions that for some reason automatically need to - // create an Olm session. - for item in self.key_request_machine.users_for_key_claim().iter() { - let user = item.key(); - - for device_id in item.value().iter() { - missing - .entry(user.to_owned()) - .or_insert_with(BTreeMap::new) - .insert(device_id.to_owned(), DeviceKeyAlgorithm::SignedCurve25519); - } - } - - if missing.is_empty() { - Ok(None) - } else { - Ok(Some(( - Uuid::new_v4(), - assign!(KeysClaimRequest::new(missing), { - timeout: Some(Self::KEY_CLAIM_TIMEOUT), - }), - ))) - } - } - - /// Receive a successful key claim response and create new Olm sessions with - /// the claimed keys. - /// - /// # Arguments - /// - /// * `response` - The response containing the claimed one-time keys. - pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { - // TODO log the failures here - - for (user_id, user_devices) in &response.one_time_keys { - for (device_id, key_map) in user_devices { - let device = match self.store.get_readonly_device(&user_id, device_id).await { - Ok(Some(d)) => d, - Ok(None) => { - warn!( - "Tried to create an Olm session for {} {}, but the device is unknown", - user_id, device_id - ); - continue; - } - Err(e) => { - warn!( - "Tried to create an Olm session for {} {}, but \ - can't fetch the device from the store {:?}", - user_id, device_id, e - ); - continue; - } - }; - - info!("Creating outbound Session for {} {}", user_id, device_id); - - let session = match self.account.create_outbound_session(device, &key_map).await { - Ok(s) => s, - Err(e) => { - warn!("{:?}", e); - continue; - } - }; - - if let Err(e) = self.store.save_sessions(&[session]).await { - error!("Failed to store newly created Olm session {}", e); - continue; - } - - // TODO if this session was created because a previous one was - // wedged queue up a dummy event to be sent out. - self.key_request_machine.retry_keyshare(&user_id, device_id); - } - } - Ok(()) - } -} diff --git a/matrix_sdk_crypto/src/group_manager.rs b/matrix_sdk_crypto/src/session_manager/group_sessions.rs similarity index 92% rename from matrix_sdk_crypto/src/group_manager.rs rename to matrix_sdk_crypto/src/session_manager/group_sessions.rs index d20f8f8e7..374610fd0 100644 --- a/matrix_sdk_crypto/src/group_manager.rs +++ b/matrix_sdk_crypto/src/session_manager/group_sessions.rs @@ -28,8 +28,8 @@ use tracing::{debug, info}; use crate::{ error::{EventError, MegolmResult, OlmResult}, - olm::{Account, OutboundGroupSession}, - store::Store, + olm::{Account, InboundGroupSession, OutboundGroupSession}, + store::{Changes, Store}, Device, EncryptionSettings, OlmError, ToDeviceRequest, }; @@ -140,19 +140,17 @@ impl GroupSessionManager { &self, room_id: &RoomId, settings: EncryptionSettings, - ) -> OlmResult<()> { + ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> { let (outbound, inbound) = self .account .create_group_session_pair(room_id, settings) .await .map_err(|_| EventError::UnsupportedAlgorithm)?; - let _ = self.store.save_inbound_group_sessions(&[inbound]).await?; - let _ = self .outbound_group_sessions - .insert(room_id.to_owned(), outbound); - Ok(()) + .insert(room_id.to_owned(), outbound.clone()); + Ok((outbound, inbound)) } /// Get to-device requests to share a group session with users in a room. @@ -169,13 +167,12 @@ impl GroupSessionManager { users: impl Iterator, encryption_settings: impl Into, ) -> OlmResult>> { - self.create_outbound_group_session(room_id, encryption_settings.into()) - .await?; - let session = self.outbound_group_sessions.get(room_id).unwrap(); + let mut changes = Changes::default(); - if session.shared() { - panic!("Session is already shared"); - } + let (session, inbound_session) = self + .create_outbound_group_session(room_id, encryption_settings.into()) + .await?; + changes.inbound_group_sessions.push(inbound_session); let mut devices: Vec = Vec::new(); @@ -196,7 +193,7 @@ impl GroupSessionManager { .encrypt(EventType::RoomKey, key_content.clone()) .await; - let encrypted = match encrypted { + let (used_session, encrypted) = match encrypted { Ok(c) => c, Err(OlmError::MissingSession) | Err(OlmError::EventError(EventError::MissingSenderKey)) => { @@ -205,6 +202,8 @@ impl GroupSessionManager { Err(e) => return Err(e), }; + changes.sessions.push(used_session); + messages .entry(device.user_id().clone()) .or_insert_with(BTreeMap::new) @@ -237,6 +236,8 @@ impl GroupSessionManager { session.mark_as_shared(); } + self.store.save_changes(changes).await?; + Ok(requests) } } diff --git a/matrix_sdk_crypto/src/session_manager/mod.rs b/matrix_sdk_crypto/src/session_manager/mod.rs new file mode 100644 index 000000000..7750262ed --- /dev/null +++ b/matrix_sdk_crypto/src/session_manager/mod.rs @@ -0,0 +1,19 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod group_sessions; +mod sessions; + +pub(crate) use group_sessions::GroupSessionManager; +pub(crate) use sessions::SessionManager; diff --git a/matrix_sdk_crypto/src/session_manager/sessions.rs b/matrix_sdk_crypto/src/session_manager/sessions.rs new file mode 100644 index 000000000..29af3b6f2 --- /dev/null +++ b/matrix_sdk_crypto/src/session_manager/sessions.rs @@ -0,0 +1,498 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::BTreeMap, sync::Arc, time::Duration}; + +use dashmap::{DashMap, DashSet}; +use matrix_sdk_common::{ + api::r0::{ + keys::claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse}, + to_device::DeviceIdOrAllDevices, + }, + assign, + events::EventType, + identifiers::{DeviceId, DeviceIdBox, DeviceKeyAlgorithm, UserId}, + uuid::Uuid, +}; +use serde_json::{json, value::to_raw_value}; +use tracing::{error, info, warn}; + +use crate::{ + error::OlmResult, + key_request::KeyRequestMachine, + olm::Account, + requests::{OutgoingRequest, ToDeviceRequest}, + store::{Changes, Result as StoreResult, Store}, + ReadOnlyDevice, +}; + +#[derive(Debug, Clone)] +pub(crate) struct SessionManager { + account: Account, + store: Store, + /// A map of user/devices that we need to automatically claim keys for. + /// Submodules can insert user/device pairs into this map and the + /// user/device paris will be added to the list of users when + /// [`get_missing_sessions`](#method.get_missing_sessions) is called. + users_for_key_claim: Arc>>, + wedged_devices: Arc>>, + key_request_machine: KeyRequestMachine, + outgoing_to_device_requests: Arc>, +} + +impl SessionManager { + const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10); + const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60); + + pub fn new( + account: Account, + users_for_key_claim: Arc>>, + key_request_machine: KeyRequestMachine, + store: Store, + ) -> Self { + Self { + account, + store, + key_request_machine, + users_for_key_claim, + wedged_devices: Arc::new(DashMap::new()), + outgoing_to_device_requests: Arc::new(DashMap::new()), + } + } + + /// Mark the outgoing request as sent. + pub fn mark_outgoing_request_as_sent(&self, id: &Uuid) { + self.outgoing_to_device_requests.remove(id); + } + + pub async fn mark_device_as_wedged(&self, sender: &UserId, curve_key: &str) -> StoreResult<()> { + if let Some(device) = self + .store + .get_device_from_curve_key(sender, curve_key) + .await? + { + let sessions = device.get_sessions().await?; + + if let Some(sessions) = sessions { + let mut sessions = sessions.lock().await; + sessions.sort_by_key(|s| s.creation_time.clone()); + + let session = sessions.get(0); + + if let Some(session) = session { + if session.creation_time.elapsed() > Self::UNWEDGING_INTERVAL { + self.users_for_key_claim + .entry(device.user_id().clone()) + .or_insert_with(DashSet::new) + .insert(device.device_id().into()); + self.wedged_devices + .entry(device.user_id().to_owned()) + .or_insert_with(DashSet::new) + .insert(device.device_id().into()); + } + } + } + } + + Ok(()) + } + + #[allow(dead_code)] + pub fn is_device_wedged(&self, device: &ReadOnlyDevice) -> bool { + self.wedged_devices + .get(device.user_id()) + .map(|d| d.contains(device.device_id())) + .unwrap_or(false) + } + + /// Check if the session was created to unwedge a Device. + /// + /// If the device was wedged this will queue up a dummy to-device message. + async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> { + if self + .wedged_devices + .get(user_id) + .map(|d| d.remove(device_id)) + .flatten() + .is_some() + { + if let Some(device) = self.store.get_device(user_id, device_id).await? { + let (_, content) = device.encrypt(EventType::Dummy, json!({})).await?; + let id = Uuid::new_v4(); + let mut messages = BTreeMap::new(); + + messages + .entry(device.user_id().to_owned()) + .or_insert_with(BTreeMap::new) + .insert( + DeviceIdOrAllDevices::DeviceId(device.device_id().into()), + to_raw_value(&content)?, + ); + + let request = OutgoingRequest { + request_id: id, + request: Arc::new( + ToDeviceRequest { + event_type: EventType::RoomEncrypted, + txn_id: id, + messages, + } + .into(), + ), + }; + + self.outgoing_to_device_requests.insert(id, request); + } + } + + Ok(()) + } + + /// Get the a key claiming request for the user/device pairs that we are + /// missing Olm sessions for. + /// + /// Returns None if no key claiming request needs to be sent out. + /// + /// Sessions need to be established between devices so group sessions for a + /// room can be shared with them. + /// + /// This should be called every time a group session needs to be shared as + /// well as between sync calls. After a sync some devices may request room + /// keys without us having a valid Olm session with them, making it + /// impossible to server the room key request, thus it's necessary to check + /// for missing sessions between sync as well. + /// + /// **Note**: Care should be taken that only one such request at a time is + /// in flight, e.g. using a lock. + /// + /// The response of a successful key claiming requests needs to be passed to + /// the `OlmMachine` with the [`receive_keys_claim_response`]. + /// + /// # Arguments + /// + /// `users` - The list of users that we should check if we lack a session + /// with one of their devices. This can be an empty iterator when calling + /// this method between sync requests. + /// + /// [`receive_keys_claim_response`]: #method.receive_keys_claim_response + pub async fn get_missing_sessions( + &self, + users: &mut impl Iterator, + ) -> OlmResult> { + let mut missing = BTreeMap::new(); + + // Add the list of devices that the user wishes to establish sessions + // right now. + for user_id in users { + let user_devices = self.store.get_user_devices(user_id).await?; + + for device in user_devices.devices() { + let sender_key = if let Some(k) = device.get_key(DeviceKeyAlgorithm::Curve25519) { + k + } else { + continue; + }; + + let sessions = self.store.get_sessions(sender_key).await?; + + let is_missing = if let Some(sessions) = sessions { + sessions.lock().await.is_empty() + } else { + true + }; + + if is_missing { + missing + .entry(user_id.to_owned()) + .or_insert_with(BTreeMap::new) + .insert( + device.device_id().into(), + DeviceKeyAlgorithm::SignedCurve25519, + ); + } + } + } + + // Add the list of sessions that for some reason automatically need to + // create an Olm session. + for item in self.users_for_key_claim.iter() { + let user = item.key(); + + for device_id in item.value().iter() { + missing + .entry(user.to_owned()) + .or_insert_with(BTreeMap::new) + .insert(device_id.to_owned(), DeviceKeyAlgorithm::SignedCurve25519); + } + } + + if missing.is_empty() { + Ok(None) + } else { + Ok(Some(( + Uuid::new_v4(), + assign!(KeysClaimRequest::new(missing), { + timeout: Some(Self::KEY_CLAIM_TIMEOUT), + }), + ))) + } + } + + /// Receive a successful key claim response and create new Olm sessions with + /// the claimed keys. + /// + /// # Arguments + /// + /// * `response` - The response containing the claimed one-time keys. + pub async fn receive_keys_claim_response(&self, response: &KeysClaimResponse) -> OlmResult<()> { + // TODO log the failures here + + let mut changes = Changes::default(); + + for (user_id, user_devices) in &response.one_time_keys { + for (device_id, key_map) in user_devices { + let device = match self.store.get_readonly_device(&user_id, device_id).await { + Ok(Some(d)) => d, + Ok(None) => { + warn!( + "Tried to create an Olm session for {} {}, but the device is unknown", + user_id, device_id + ); + continue; + } + Err(e) => { + warn!( + "Tried to create an Olm session for {} {}, but \ + can't fetch the device from the store {:?}", + user_id, device_id, e + ); + continue; + } + }; + + info!("Creating outbound Session for {} {}", user_id, device_id); + + let session = match self.account.create_outbound_session(device, &key_map).await { + Ok(s) => s, + Err(e) => { + warn!("Error creating new outbound session {:?}", e); + continue; + } + }; + + changes.sessions.push(session); + + self.key_request_machine.retry_keyshare(&user_id, device_id); + + if let Err(e) = self.check_if_unwedged(&user_id, device_id).await { + error!( + "Error while treating an unwedged device {} {} {:?}", + user_id, device_id, e + ); + } + } + } + + Ok(self.store.save_changes(changes).await?) + } +} + +#[cfg(test)] +mod test { + use dashmap::DashMap; + use matrix_sdk_common::locks::Mutex; + use std::{collections::BTreeMap, sync::Arc}; + + use matrix_sdk_common::{ + api::r0::keys::claim_keys::Response as KeyClaimResponse, + identifiers::{user_id, DeviceIdBox, DeviceKeyAlgorithm, UserId}, + instant::{Duration, Instant}, + }; + use matrix_sdk_test::async_test; + + use super::SessionManager; + use crate::{ + identities::ReadOnlyDevice, + key_request::KeyRequestMachine, + olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + store::{CryptoStore, MemoryStore, Store}, + verification::VerificationMachine, + }; + + fn user_id() -> UserId { + user_id!("@example:localhost") + } + + fn device_id() -> DeviceIdBox { + "DEVICEID".into() + } + + fn bob_account() -> ReadOnlyAccount { + ReadOnlyAccount::new(&user_id!("@bob:localhost"), "BOBDEVICE".into()) + } + + async fn session_manager() -> SessionManager { + let user_id = user_id(); + let device_id = device_id(); + + let outbound_sessions = Arc::new(DashMap::new()); + let users_for_key_claim = Arc::new(DashMap::new()); + let account = ReadOnlyAccount::new(&user_id, &device_id); + let store: Arc> = Arc::new(Box::new(MemoryStore::new())); + store.save_account(account.clone()).await.unwrap(); + let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty( + user_id.clone(), + ))); + let verification = VerificationMachine::new(account.clone(), identity, store.clone()); + + let user_id = Arc::new(user_id); + let device_id = Arc::new(device_id); + + let store = Store::new(user_id.clone(), store, verification); + + let account = Account { + inner: account, + store: store.clone(), + }; + + let key_request = KeyRequestMachine::new( + user_id, + device_id, + store.clone(), + outbound_sessions, + users_for_key_claim.clone(), + ); + + SessionManager::new(account, users_for_key_claim, key_request, store) + } + + #[async_test] + async fn session_creation() { + let manager = session_manager().await; + let bob = bob_account(); + + let bob_device = ReadOnlyDevice::from_account(&bob).await; + + manager.store.save_devices(&[bob_device]).await.unwrap(); + + let (_, request) = manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .unwrap(); + + assert!(request.one_time_keys.contains_key(bob.user_id())); + + bob.generate_one_time_keys_helper(1).await; + let one_time = bob.signed_one_time_keys_helper().await.unwrap(); + bob.mark_keys_as_published().await; + + let mut one_time_keys = BTreeMap::new(); + one_time_keys + .entry(bob.user_id().clone()) + .or_insert_with(BTreeMap::new) + .insert(bob.device_id().into(), one_time); + + let response = KeyClaimResponse { + failures: BTreeMap::new(), + one_time_keys, + }; + + manager + .receive_keys_claim_response(&response) + .await + .unwrap(); + + assert!(manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .is_none()); + } + + // This test doesn't run on macos because we're modifying the session + // creation time so we can get around the UNWEDGING_INTERVAL. + #[async_test] + #[cfg(not(target_os = "macos"))] + async fn session_unwedging() { + let manager = session_manager().await; + let bob = bob_account(); + let (_, mut session) = bob.create_session_for(&manager.account).await; + + let bob_device = ReadOnlyDevice::from_account(&bob).await; + session.creation_time = Arc::new(Instant::now() - Duration::from_secs(3601)); + + manager + .store + .save_devices(&[bob_device.clone()]) + .await + .unwrap(); + manager.store.save_sessions(&[session]).await.unwrap(); + + assert!(manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .is_none()); + + let curve_key = bob_device.get_key(DeviceKeyAlgorithm::Curve25519).unwrap(); + + assert!(!manager.users_for_key_claim.contains_key(bob.user_id())); + assert!(!manager.is_device_wedged(&bob_device)); + manager + .mark_device_as_wedged(bob_device.user_id(), &curve_key) + .await + .unwrap(); + assert!(manager.is_device_wedged(&bob_device)); + assert!(manager.users_for_key_claim.contains_key(bob.user_id())); + + let (_, request) = manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .unwrap(); + + assert!(request.one_time_keys.contains_key(bob.user_id())); + + bob.generate_one_time_keys_helper(1).await; + let one_time = bob.signed_one_time_keys_helper().await.unwrap(); + bob.mark_keys_as_published().await; + + let mut one_time_keys = BTreeMap::new(); + one_time_keys + .entry(bob.user_id().clone()) + .or_insert_with(BTreeMap::new) + .insert(bob.device_id().into(), one_time); + + let response = KeyClaimResponse { + failures: BTreeMap::new(), + one_time_keys, + }; + + assert!(manager.outgoing_to_device_requests.is_empty()); + + manager + .receive_keys_claim_response(&response) + .await + .unwrap(); + + assert!(!manager.is_device_wedged(&bob_device)); + assert!(manager + .get_missing_sessions(&mut [bob.user_id().clone()].iter()) + .await + .unwrap() + .is_none()); + assert!(!manager.outgoing_to_device_requests.is_empty()) + } +} diff --git a/matrix_sdk_crypto/src/store/caches.rs b/matrix_sdk_crypto/src/store/caches.rs index eb66198da..3b306d6bd 100644 --- a/matrix_sdk_crypto/src/store/caches.rs +++ b/matrix_sdk_crypto/src/store/caches.rs @@ -19,9 +19,9 @@ use std::{collections::HashMap, sync::Arc}; -use dashmap::{DashMap, ReadOnlyView}; +use dashmap::DashMap; use matrix_sdk_common::{ - identifiers::{DeviceId, RoomId, UserId}, + identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, }; @@ -145,29 +145,6 @@ pub struct DeviceStore { entries: Arc, ReadOnlyDevice>>>, } -/// A read only view over all devices belonging to a user. -#[derive(Debug)] -pub struct ReadOnlyUserDevices { - entries: ReadOnlyView, ReadOnlyDevice>, -} - -impl ReadOnlyUserDevices { - /// Get the specific device with the given device id. - pub fn get(&self, device_id: &DeviceId) -> Option { - self.entries.get(device_id).cloned() - } - - /// Iterator over all the device ids of the user devices. - pub fn keys(&self) -> impl Iterator { - self.entries.keys().map(|id| id.as_ref()) - } - - /// Iterator over all the devices of the user devices. - pub fn devices(&self) -> impl Iterator { - self.entries.values() - } -} - impl DeviceStore { /// Create a new empty device store. pub fn new() -> Self { @@ -206,15 +183,13 @@ impl DeviceStore { } /// Get a read-only view over all devices of the given user. - pub fn user_devices(&self, user_id: &UserId) -> ReadOnlyUserDevices { - ReadOnlyUserDevices { - entries: self - .entries - .entry(user_id.clone()) - .or_insert_with(DashMap::new) - .clone() - .into_read_only(), - } + pub fn user_devices(&self, user_id: &UserId) -> HashMap { + self.entries + .entry(user_id.clone()) + .or_insert_with(DashMap::new) + .iter() + .map(|i| (i.key().to_owned(), i.value().clone())) + .collect() } } @@ -305,12 +280,12 @@ mod test { let user_devices = store.user_devices(device.user_id()); - assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.devices().next().unwrap(), &device); + assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); + assert_eq!(user_devices.values().next().unwrap(), &device); let loaded_device = user_devices.get(device.device_id()).unwrap(); - assert_eq!(device, loaded_device); + assert_eq!(&device, loaded_device); store.remove(device.user_id(), device.device_id()); diff --git a/matrix_sdk_crypto/src/store/memorystore.rs b/matrix_sdk_crypto/src/store/memorystore.rs index d7f6a4dab..952522e57 100644 --- a/matrix_sdk_crypto/src/store/memorystore.rs +++ b/matrix_sdk_crypto/src/store/memorystore.rs @@ -12,20 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use dashmap::{DashMap, DashSet}; use matrix_sdk_common::{ - identifiers::{DeviceId, RoomId, UserId}, + identifiers::{DeviceId, DeviceIdBox, RoomId, UserId}, locks::Mutex, }; use matrix_sdk_common_macros::async_trait; use super::{ - caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore}, - CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, + caches::{DeviceStore, GroupSessionStore, SessionStore}, + Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, Session, +}; +use crate::{ + identities::{ReadOnlyDevice, UserIdentities}, + olm::PrivateCrossSigningIdentity, }; -use crate::identities::{ReadOnlyDevice, UserIdentities}; /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug, Clone)] @@ -58,6 +64,30 @@ impl MemoryStore { pub fn new() -> Self { Self::default() } + + pub(crate) async fn save_devices(&self, mut devices: Vec) { + for device in devices.drain(..) { + let _ = self.devices.add(device); + } + } + + async fn delete_devices(&self, mut devices: Vec) { + for device in devices.drain(..) { + let _ = self.devices.remove(device.user_id(), device.device_id()); + } + } + + async fn save_sessions(&self, mut sessions: Vec) { + for session in sessions.drain(..) { + let _ = self.sessions.add(session.clone()).await; + } + } + + async fn save_inbound_group_sessions(&self, mut sessions: Vec) { + for session in sessions.drain(..) { + self.inbound_group_sessions.add(session); + } + } } #[async_trait] @@ -70,9 +100,24 @@ impl CryptoStore for MemoryStore { Ok(()) } - async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { - for session in sessions { - let _ = self.sessions.add(session.clone()).await; + async fn save_changes(&self, mut changes: Changes) -> Result<()> { + self.save_sessions(changes.sessions).await; + self.save_inbound_group_sessions(changes.inbound_group_sessions) + .await; + + self.save_devices(changes.devices.new).await; + self.save_devices(changes.devices.changed).await; + self.delete_devices(changes.devices.deleted).await; + + for identity in changes + .identities + .new + .drain(..) + .chain(changes.identities.changed) + { + let _ = self + .identities + .insert(identity.user_id().to_owned(), identity.clone()); } Ok(()) @@ -82,14 +127,6 @@ impl CryptoStore for MemoryStore { Ok(self.sessions.get(sender_key)) } - async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { - for session in sessions { - self.inbound_group_sessions.add(session.clone()); - } - - Ok(()) - } - async fn get_inbound_group_session( &self, room_id: &RoomId, @@ -148,37 +185,18 @@ impl CryptoStore for MemoryStore { Ok(self.devices.get(user_id, device_id)) } - async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> { - let _ = self.devices.remove(device.user_id(), device.device_id()); - Ok(()) - } - - async fn get_user_devices(&self, user_id: &UserId) -> Result { + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result> { Ok(self.devices.user_devices(user_id)) } - async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { - for device in devices { - let _ = self.devices.add(device.clone()); - } - - Ok(()) - } - async fn get_user_identity(&self, user_id: &UserId) -> Result> { #[allow(clippy::map_clone)] Ok(self.identities.get(user_id).map(|i| i.clone())) } - async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()> { - for identity in identities { - let _ = self - .identities - .insert(identity.user_id().to_owned(), identity.clone()); - } - Ok(()) - } - async fn save_value(&self, key: String, value: String) -> Result<()> { self.values.insert(key, value); Ok(()) @@ -192,6 +210,14 @@ impl CryptoStore for MemoryStore { async fn get_value(&self, key: &str) -> Result> { Ok(self.values.get(key).map(|v| v.to_owned())) } + + async fn save_identity(&self, _: PrivateCrossSigningIdentity) -> Result<()> { + Ok(()) + } + + async fn load_identity(&self) -> Result> { + Ok(None) + } } #[cfg(test)] @@ -211,7 +237,7 @@ mod test { assert!(store.load_account().await.unwrap().is_none()); store.save_account(account).await.unwrap(); - store.save_sessions(&[session.clone()]).await.unwrap(); + store.save_sessions(vec![session.clone()]).await; let sessions = store .get_sessions(&session.sender_key) @@ -244,9 +270,8 @@ mod test { let store = MemoryStore::new(); let _ = store - .save_inbound_group_sessions(&[inbound.clone()]) - .await - .unwrap(); + .save_inbound_group_sessions(vec![inbound.clone()]) + .await; let loaded_session = store .get_inbound_group_session(&room_id, "test_key", outbound.session_id()) @@ -261,7 +286,7 @@ mod test { let device = get_device(); let store = MemoryStore::new(); - store.save_devices(&[device.clone()]).await.unwrap(); + store.save_devices(vec![device.clone()]).await; let loaded_device = store .get_device(device.user_id(), device.device_id()) @@ -273,14 +298,14 @@ mod test { let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); - assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.devices().next().unwrap(), &device); + assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); + assert_eq!(user_devices.values().next().unwrap(), &device); let loaded_device = user_devices.get(device.device_id()).unwrap(); - assert_eq!(device, loaded_device); + assert_eq!(&device, loaded_device); - store.delete_device(device.clone()).await.unwrap(); + store.delete_devices(vec![device.clone()]).await; assert!(store .get_device(device.user_id(), device.device_id()) .await diff --git a/matrix_sdk_crypto/src/store/mod.rs b/matrix_sdk_crypto/src/store/mod.rs index 83c2a0f98..b19e3b07f 100644 --- a/matrix_sdk_crypto/src/store/mod.rs +++ b/matrix_sdk_crypto/src/store/mod.rs @@ -39,23 +39,30 @@ pub mod caches; mod memorystore; +mod pickle_key; #[cfg(not(target_arch = "wasm32"))] #[cfg(feature = "sqlite_cryptostore")] pub(crate) mod sqlite; -use caches::ReadOnlyUserDevices; +use matrix_sdk_common::identifiers::DeviceIdBox; pub use memorystore::MemoryStore; +pub use pickle_key::{EncryptedPickleKey, PickleKey}; #[cfg(not(target_arch = "wasm32"))] #[cfg(feature = "sqlite_cryptostore")] pub use sqlite::SqliteStore; -use std::{collections::HashSet, fmt::Debug, io::Error as IoError, ops::Deref, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Debug, + io::Error as IoError, + ops::Deref, + sync::Arc, +}; use olm_rs::errors::{OlmAccountError, OlmGroupSessionError, OlmSessionError}; use serde::{Deserialize, Serialize}; use serde_json::Error as SerdeError; use thiserror::Error; -use url::ParseError; #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] #[cfg(not(target_arch = "wasm32"))] @@ -63,7 +70,9 @@ use url::ParseError; use sqlx::Error as SqlxError; use matrix_sdk_common::{ - identifiers::{DeviceId, Error as IdentifierValidationError, RoomId, UserId}, + identifiers::{ + DeviceId, DeviceKeyAlgorithm, Error as IdentifierValidationError, RoomId, UserId, + }, locks::Mutex, }; use matrix_sdk_common_macros::async_trait; @@ -73,7 +82,7 @@ use matrix_sdk_common_macros::send_sync; use crate::{ error::SessionUnpicklingError, identities::{Device, ReadOnlyDevice, UserDevices, UserIdentities}, - olm::{InboundGroupSession, ReadOnlyAccount, Session}, + olm::{InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, Session}, verification::VerificationMachine, }; @@ -93,6 +102,31 @@ pub(crate) struct Store { verification_machine: VerificationMachine, } +#[derive(Debug, Default)] +#[allow(missing_docs)] +pub struct Changes { + pub account: Option, + pub sessions: Vec, + pub inbound_group_sessions: Vec, + pub identities: IdentityChanges, + pub devices: DeviceChanges, +} + +#[derive(Debug, Clone, Default)] +#[allow(missing_docs)] +pub struct IdentityChanges { + pub new: Vec, + pub changed: Vec, +} + +#[derive(Debug, Clone, Default)] +#[allow(missing_docs)] +pub struct DeviceChanges { + pub new: Vec, + pub changed: Vec, + pub deleted: Vec, +} + impl Store { pub fn new( user_id: Arc, @@ -114,10 +148,61 @@ impl Store { self.inner.get_device(user_id, device_id).await } - pub async fn get_readonly_devices(&self, user_id: &UserId) -> Result { + pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { + let changes = Changes { + sessions: sessions.to_vec(), + ..Default::default() + }; + + self.save_changes(changes).await + } + + #[cfg(test)] + pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { + let changes = Changes { + devices: DeviceChanges { + changed: devices.to_vec(), + ..Default::default() + }, + ..Default::default() + }; + + self.save_changes(changes).await + } + + #[cfg(test)] + pub async fn save_inbound_group_sessions( + &self, + sessions: &[InboundGroupSession], + ) -> Result<()> { + let changes = Changes { + inbound_group_sessions: sessions.to_vec(), + ..Default::default() + }; + + self.save_changes(changes).await + } + + pub async fn get_readonly_devices( + &self, + user_id: &UserId, + ) -> Result> { self.inner.get_user_devices(user_id).await } + pub async fn get_device_from_curve_key( + &self, + user_id: &UserId, + curve_key: &str, + ) -> Result> { + self.get_user_devices(user_id).await.map(|d| { + d.devices().find(|d| { + d.get_key(DeviceKeyAlgorithm::Curve25519) + .map_or(false, |k| k == curve_key) + }) + }) + } + pub async fn get_user_devices(&self, user_id: &UserId) -> Result { let devices = self.inner.get_user_devices(user_id).await?; @@ -223,6 +308,10 @@ pub enum CryptoStoreError { #[error(transparent)] SessionUnpickling(#[from] SessionUnpicklingError), + /// Failed to decrypt an pickled object. + #[error("An object failed to be decrypted while unpickling")] + UnpicklingError, + /// A Matirx identifier failed to be validated. #[error(transparent)] IdentifierValidation(#[from] IdentifierValidationError), @@ -230,10 +319,6 @@ pub enum CryptoStoreError { /// The store failed to (de)serialize a data type. #[error(transparent)] Serialization(#[from] SerdeError), - - /// An error occurred while parsing an URL. - #[error(transparent)] - UrlParse(#[from] ParseError), } /// Trait abstracting a store that the `OlmMachine` uses to store cryptographic @@ -252,12 +337,14 @@ pub trait CryptoStore: Debug { /// * `account` - The account that should be stored. async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; - /// Save the given sessions in the store. - /// - /// # Arguments - /// - /// * `session` - The sessions that should be stored. - async fn save_sessions(&self, session: &[Session]) -> Result<()>; + /// TODO + async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()>; + + /// TODO + async fn load_identity(&self) -> Result>; + + /// TODO + async fn save_changes(&self, changes: Changes) -> Result<()>; /// Get all the sessions that belong to the given sender key. /// @@ -266,13 +353,6 @@ pub trait CryptoStore: Debug { /// * `sender_key` - The sender key that was used to establish the sessions. async fn get_sessions(&self, sender_key: &str) -> Result>>>>; - /// Save the given inbound group sessions in the store. - /// - /// # Arguments - /// - /// * `sessions` - The sessions that should be stored. - async fn save_inbound_group_sessions(&self, session: &[InboundGroupSession]) -> Result<()>; - /// Get the inbound group session from our store. /// /// # Arguments @@ -312,20 +392,6 @@ pub trait CryptoStore: Debug { /// * `dirty` - Should the user be also marked for a key query. async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result; - /// Save the given devices in the store. - /// - /// # Arguments - /// - /// * `device` - The device that should be stored. - async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()>; - - /// Delete the given device from the store. - /// - /// # Arguments - /// - /// * `device` - The device that should be stored. - async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()>; - /// Get the device for the given user with the given device id. /// /// # Arguments @@ -344,14 +410,10 @@ pub trait CryptoStore: Debug { /// # Arguments /// /// * `user_id` - The user for which we should get all the devices. - async fn get_user_devices(&self, user_id: &UserId) -> Result; - - /// Save the given user identities in the store. - /// - /// # Arguments - /// - /// * `identities` - The identities that should be saved in the store. - async fn save_user_identities(&self, identities: &[UserIdentities]) -> Result<()>; + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result>; /// Get the user identity that is attached to the given user id. /// diff --git a/matrix_sdk_crypto/src/store/pickle_key.rs b/matrix_sdk_crypto/src/store/pickle_key.rs new file mode 100644 index 000000000..c49f6352d --- /dev/null +++ b/matrix_sdk_crypto/src/store/pickle_key.rs @@ -0,0 +1,206 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::TryFrom; + +use aes_gcm::{ + aead::{generic_array::GenericArray, Aead, NewAead}, + Aes256Gcm, Error as DecryptionError, +}; +use getrandom::getrandom; +use hmac::Hmac; +use olm_rs::PicklingMode; +use pbkdf2::pbkdf2; +use sha2::Sha256; +use zeroize::{Zeroize, Zeroizing}; + +use serde::{Deserialize, Serialize}; + +const KEY_SIZE: usize = 32; +const NONCE_SIZE: usize = 12; +const KDF_SALT_SIZE: usize = 32; +const KDF_ROUNDS: u32 = 10000; + +/// Version specific info for the key derivation method that is used. +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub enum KdfInfo { + Pbkdf2 { + /// The number of PBKDF rounds that were used when deriving the AES key. + rounds: u32, + }, +} + +/// Version specific info for encryption method that is used to encrypt our +/// pickle key. +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub enum CipherTextInfo { + Aes256Gcm { + /// The nonce that was used to encrypt the ciphertext. + nonce: Vec, + /// The encrypted pickle key. + ciphertext: Vec, + }, +} + +/// An encrypted version of our pickle key, this can be safely stored in a +/// database. +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct EncryptedPickleKey { + /// Info about the key derivation method that was used to expand the + /// passphrase into an encryption key. + pub kdf_info: KdfInfo, + /// The ciphertext with it's accompanying additional data that is needed to + /// decrypt the pickle key. + pub ciphertext_info: CipherTextInfo, + /// The salt that was used when the passphrase was expanded into a AES key. + kdf_salt: Vec, +} + +/// A pickle key that will be used to encrypt all the private keys for Olm. +/// +/// Olm uses AES256 to encrypt accounts, sessions, inbound group sessions. We +/// also implement our own pickling for the cross-signing types using +/// AES256-GCM so the key sizes match. +#[derive(Debug, Zeroize, PartialEq)] +pub struct PickleKey { + aes256_key: Vec, +} + +impl Default for PickleKey { + fn default() -> Self { + let mut key = vec![0u8; KEY_SIZE]; + getrandom(&mut key).expect("Can't generate new pickle key"); + + Self { aes256_key: key } + } +} + +impl TryFrom> for PickleKey { + type Error = (); + fn try_from(value: Vec) -> Result { + if value.len() != KEY_SIZE { + Err(()) + } else { + Ok(Self { aes256_key: value }) + } + } +} + +impl PickleKey { + /// Generate a new random pickle key. + pub fn new() -> Self { + Default::default() + } + + fn expand_key(passphrase: &str, salt: &[u8], rounds: u32) -> Zeroizing> { + let mut key = Zeroizing::from(vec![0u8; KEY_SIZE]); + pbkdf2::>(passphrase.as_bytes(), &salt, rounds, &mut *key); + key + } + + /// Get a `PicklingMode` version of this pickle key. + pub fn pickle_mode(&self) -> PicklingMode { + PicklingMode::Encrypted { + key: self.aes256_key.clone(), + } + } + + /// Get the raw AES256 key. + pub fn key(&self) -> &[u8] { + &self.aes256_key + } + + /// Encrypt and export our pickle key using the given passphrase. + /// + /// # Arguments + /// + /// * `passphrase` - The passphrase that should be used to encrypt the + /// pickle key. + pub fn encrypt(&self, passphrase: &str) -> EncryptedPickleKey { + let mut salt = vec![0u8; KDF_SALT_SIZE]; + getrandom(&mut salt).expect("Can't generate new random pickle key"); + + let key = PickleKey::expand_key(passphrase, &salt, KDF_ROUNDS); + let key = GenericArray::from_slice(key.as_ref()); + let cipher = Aes256Gcm::new(&key); + + let mut nonce = vec![0u8; NONCE_SIZE]; + getrandom(&mut nonce).expect("Can't generate new random nonce for the pickle key"); + + let ciphertext = cipher + .encrypt( + &GenericArray::from_slice(nonce.as_ref()), + self.aes256_key.as_slice(), + ) + .expect("Can't encrypt pickle key"); + + EncryptedPickleKey { + kdf_info: KdfInfo::Pbkdf2 { rounds: KDF_ROUNDS }, + kdf_salt: salt, + ciphertext_info: CipherTextInfo::Aes256Gcm { nonce, ciphertext }, + } + } + + /// Restore a pickle key from an encrypted export. + /// + /// # Arguments + /// + /// * `passphrase` - The passphrase that should be used to encrypt the + /// pickle key. + /// + /// * `encrypted` - The exported and encrypted version of the pickle key. + pub fn from_encrypted( + passphrase: &str, + encrypted: EncryptedPickleKey, + ) -> Result { + let key = match encrypted.kdf_info { + KdfInfo::Pbkdf2 { rounds } => Self::expand_key(passphrase, &encrypted.kdf_salt, rounds), + }; + + let key = GenericArray::from_slice(key.as_ref()); + + let decrypted = match encrypted.ciphertext_info { + CipherTextInfo::Aes256Gcm { nonce, ciphertext } => { + let cipher = Aes256Gcm::new(&key); + let nonce = GenericArray::from_slice(&nonce); + cipher.decrypt(nonce, ciphertext.as_ref())? + } + }; + + Ok(Self { + aes256_key: decrypted, + }) + } +} + +#[cfg(test)] +mod test { + use super::PickleKey; + + #[test] + fn generating() { + PickleKey::new(); + } + + #[test] + fn encrypting() { + let passphrase = "it's a secret to everybody"; + let pickle_key = PickleKey::new(); + + let encrypted = pickle_key.encrypt(passphrase); + let decrypted = PickleKey::from_encrypted(passphrase, encrypted).unwrap(); + + assert_eq!(pickle_key, decrypted); + } +} diff --git a/matrix_sdk_crypto/src/store/sqlite.rs b/matrix_sdk_crypto/src/store/sqlite.rs index fe67e2d61..3df06ab25 100644 --- a/matrix_sdk_crypto/src/store/sqlite.rs +++ b/matrix_sdk_crypto/src/store/sqlite.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, convert::TryFrom, path::{Path, PathBuf}, result::Result as StdResult, @@ -25,28 +25,32 @@ use dashmap::DashSet; use matrix_sdk_common::{ api::r0::keys::{CrossSigningKey, KeyUsage}, identifiers::{ - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, UserId, + DeviceId, DeviceIdBox, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, RoomId, + UserId, }, instant::Duration, locks::Mutex, }; -use sqlx::{query, query_as, sqlite::SqliteQueryAs, Connect, Executor, SqliteConnection}; -use url::Url; -use zeroize::Zeroizing; +use sqlx::{query, query_as, sqlite::SqliteConnectOptions, Connection, Executor, SqliteConnection}; use super::{ - caches::{DeviceStore, GroupSessionStore, ReadOnlyUserDevices, SessionStore}, - CryptoStore, CryptoStoreError, Result, + caches::SessionStore, + pickle_key::{EncryptedPickleKey, PickleKey}, + Changes, CryptoStore, CryptoStoreError, Result, }; use crate::{ identities::{LocalTrust, OwnUserIdentity, ReadOnlyDevice, UserIdentities, UserIdentity}, olm::{ AccountPickle, IdentityKeys, InboundGroupSession, InboundGroupSessionPickle, - PickledAccount, PickledInboundGroupSession, PickledSession, PicklingMode, ReadOnlyAccount, - Session, SessionPickle, + PickledAccount, PickledCrossSigningIdentity, PickledInboundGroupSession, PickledSession, + PicklingMode, PrivateCrossSigningIdentity, ReadOnlyAccount, Session, SessionPickle, }, }; +/// This needs to be 32 bytes long since AES-GCM requires it, otherwise we will +/// panic once we try to pickle a Signing object. +const DEFAULT_PICKLE: &str = "DEFAULT_PICKLE_PASSPHRASE_123456"; + /// SQLite based implementation of a `CryptoStore`. #[derive(Clone)] #[cfg_attr(feature = "docs", doc(cfg(r#sqlite_cryptostore)))] @@ -57,13 +61,11 @@ pub struct SqliteStore { path: Arc, sessions: SessionStore, - inbound_group_sessions: GroupSessionStore, - devices: DeviceStore, tracked_users: Arc>, users_for_key_query: Arc>, connection: Arc>, - pickle_passphrase: Arc>>, + pickle_key: Arc, } #[derive(Clone)] @@ -122,44 +124,43 @@ impl SqliteStore { path: P, passphrase: &str, ) -> Result { - SqliteStore::open_helper( - user_id, - device_id, - path, - Some(Zeroizing::new(passphrase.to_owned())), - ) - .await - } - - fn path_to_url(path: &Path) -> Result { - // TODO this returns an empty error if the path isn't absolute. - let url = Url::from_directory_path(path).expect("Invalid path"); - Ok(url.join(DATABASE_NAME)?) + SqliteStore::open_helper(user_id, device_id, path, Some(passphrase)).await } async fn open_helper>( user_id: &UserId, device_id: &DeviceId, path: P, - passphrase: Option>, + passphrase: Option<&str>, ) -> Result { - let url = SqliteStore::path_to_url(path.as_ref())?; - let connection = SqliteConnection::connect(url.as_ref()).await?; + let path = path.as_ref().join(DATABASE_NAME); + let options = SqliteConnectOptions::new() + .foreign_keys(true) + .create_if_missing(true) + .read_only(false) + .filename(&path); + + let mut connection = SqliteConnection::connect_with(&options).await?; + Self::create_tables(&mut connection).await?; + + let pickle_key = if let Some(passphrase) = passphrase { + Self::get_or_create_pickle_key(user_id, device_id, &passphrase, &mut connection).await? + } else { + PickleKey::try_from(DEFAULT_PICKLE.as_bytes().to_vec()) + .expect("Can't create default pickle key") + }; let store = SqliteStore { user_id: Arc::new(user_id.to_owned()), device_id: Arc::new(device_id.into()), account_info: Arc::new(SyncMutex::new(None)), sessions: SessionStore::new(), - inbound_group_sessions: GroupSessionStore::new(), - devices: DeviceStore::new(), - path: Arc::new(path.as_ref().to_owned()), + path: Arc::new(path), connection: Arc::new(Mutex::new(connection)), - pickle_passphrase: Arc::new(passphrase), tracked_users: Arc::new(DashSet::new()), users_for_key_query: Arc::new(DashSet::new()), + pickle_key: Arc::new(pickle_key), }; - store.create_tables().await?; Ok(store) } @@ -172,8 +173,7 @@ impl SqliteStore { .map(|i| i.account_id) } - async fn create_tables(&self) -> Result<()> { - let mut connection = self.connection.lock().await; + async fn create_tables(connection: &mut SqliteConnection) -> Result<()> { connection .execute( r#" @@ -190,6 +190,37 @@ impl SqliteStore { ) .await?; + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS private_identities ( + "id" INTEGER NOT NULL PRIMARY KEY, + "account_id" INTEGER NOT NULL, + "user_id" TEXT NOT NULL, + "pickle" TEXT NOT NULL, + "shared" INTEGER NOT NULL, + FOREIGN KEY ("account_id") REFERENCES "accounts" ("id") + ON DELETE CASCADE + UNIQUE(account_id, user_id) + ); + "#, + ) + .await?; + + connection + .execute( + r#" + CREATE TABLE IF NOT EXISTS pickle_keys ( + "id" INTEGER NOT NULL PRIMARY KEY, + "user_id" TEXT NOT NULL, + "device_id" TEXT NOT NULL, + "key" TEXT NOT NULL, + UNIQUE(user_id,device_id) + ); + "#, + ) + .await?; + connection .execute( r#" @@ -463,11 +494,68 @@ impl SqliteStore { Ok(()) } - async fn lazy_load_sessions(&self, sender_key: &str) -> Result<()> { + async fn save_pickle_key( + user_id: &UserId, + device_id: &DeviceId, + key: EncryptedPickleKey, + connection: &mut SqliteConnection, + ) -> Result<()> { + let key = serde_json::to_string(&key)?; + + query( + "INSERT INTO pickle_keys ( + user_id, device_id, key + ) VALUES (?1, ?2, ?3) + ON CONFLICT(user_id, device_id) DO UPDATE SET + key = excluded.key + ", + ) + .bind(user_id.as_str()) + .bind(device_id.as_str()) + .bind(key) + .execute(&mut *connection) + .await?; + + Ok(()) + } + + async fn get_or_create_pickle_key( + user_id: &UserId, + device_id: &DeviceId, + passphrase: &str, + connection: &mut SqliteConnection, + ) -> Result { + let row: Option<(String,)> = + query_as("SELECT key FROM pickle_keys WHERE user_id = ? and device_id = ?") + .bind(user_id.as_str()) + .bind(device_id.as_str()) + .fetch_optional(&mut *connection) + .await?; + + Ok(if let Some(row) = row { + let encrypted: EncryptedPickleKey = serde_json::from_str(&row.0)?; + PickleKey::from_encrypted(passphrase, encrypted) + .map_err(|_| CryptoStoreError::UnpicklingError)? + } else { + let key = PickleKey::new(); + let encrypted = key.encrypt(passphrase); + Self::save_pickle_key(user_id, device_id, encrypted, connection).await?; + + key + }) + } + + async fn lazy_load_sessions( + &self, + connection: &mut SqliteConnection, + sender_key: &str, + ) -> Result<()> { let loaded_sessions = self.sessions.get(sender_key).is_some(); if !loaded_sessions { - let sessions = self.load_sessions_for(sender_key).await?; + let sessions = self + .load_sessions_for_helper(connection, sender_key) + .await?; if !sessions.is_empty() { self.sessions.set_for_sender(sender_key, sessions); @@ -477,20 +565,33 @@ impl SqliteStore { Ok(()) } - async fn get_sessions_for(&self, sender_key: &str) -> Result>>>> { - self.lazy_load_sessions(sender_key).await?; + async fn get_sessions_for( + &self, + connection: &mut SqliteConnection, + sender_key: &str, + ) -> Result>>>> { + self.lazy_load_sessions(connection, sender_key).await?; Ok(self.sessions.get(sender_key)) } + #[cfg(test)] async fn load_sessions_for(&self, sender_key: &str) -> Result> { + let mut connection = self.connection.lock().await; + self.load_sessions_for_helper(&mut connection, sender_key) + .await + } + + async fn load_sessions_for_helper( + &self, + connection: &mut SqliteConnection, + sender_key: &str, + ) -> Result> { let account_info = self .account_info .lock() .unwrap() .clone() .ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - let mut rows: Vec<(String, String, String, String)> = query_as( "SELECT pickle, sender_key, creation_time, last_use_time FROM sessions WHERE account_id = ? and sender_key = ?", @@ -526,7 +627,113 @@ impl SqliteStore { .collect::>>()?) } - async fn load_inbound_group_sessions(&self) -> Result<()> { + async fn load_inbound_session_data( + &self, + connection: &mut SqliteConnection, + session_row_id: i64, + pickle: String, + sender_key: String, + room_id: RoomId, + imported: bool, + ) -> Result { + let key_rows: Vec<(String, String)> = + query_as("SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?") + .bind(session_row_id) + .fetch_all(&mut *connection) + .await?; + + let claimed_keys: BTreeMap = key_rows + .into_iter() + .filter_map(|row| { + let algorithm = row.0.parse::().ok()?; + let key = row.1; + + Some((algorithm, key)) + }) + .collect(); + + let mut chain_rows: Vec<(String,)> = + query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?") + .bind(session_row_id) + .fetch_all(&mut *connection) + .await?; + + let chains: Vec = chain_rows.drain(..).map(|r| r.0).collect(); + + let chains = if chains.is_empty() { + None + } else { + Some(chains) + }; + + let pickle = PickledInboundGroupSession { + pickle: InboundGroupSessionPickle::from(pickle), + sender_key, + signing_key: claimed_keys, + room_id, + forwarding_chains: chains, + imported, + }; + + Ok(InboundGroupSession::from_pickle( + pickle, + self.get_pickle_mode(), + )?) + } + + async fn load_inbound_group_session_helper( + &self, + room_id: &RoomId, + sender_key: &str, + session_id: &str, + ) -> Result> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(i64, String, bool)> = query_as( + "SELECT id, pickle, imported + FROM inbound_group_sessions + WHERE ( + account_id = ? and + room_id = ? and + sender_key = ? and + session_id = ? + )", + ) + .bind(account_id) + .bind(room_id.as_str()) + .bind(sender_key) + .bind(session_id) + .fetch_optional(&mut *connection) + .await?; + + let row = if let Some(r) = row { + r + } else { + return Ok(None); + }; + + let session_row_id = row.0; + let pickle = row.1; + let imported = row.2; + + let session = self + .load_inbound_session_data( + &mut connection, + session_row_id, + pickle, + sender_key.to_owned(), + room_id.to_owned(), + imported, + ) + .await?; + + Ok(Some(session)) + } + + async fn load_inbound_group_sessions(&self) -> Result> { + let mut sessions = Vec::new(); + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -542,57 +749,24 @@ impl SqliteStore { let session_row_id = row.0; let pickle = row.1; let sender_key = row.2; - let room_id = row.3; + let room_id = RoomId::try_from(row.3)?; let imported = row.4; - let key_rows: Vec<(String, String)> = query_as( - "SELECT algorithm, key FROM group_session_claimed_keys WHERE session_id = ?", - ) - .bind(session_row_id) - .fetch_all(&mut *connection) - .await?; - - let claimed_keys: BTreeMap = key_rows - .into_iter() - .filter_map(|row| { - let algorithm = row.0.parse::().ok()?; - let key = row.1; - - Some((algorithm, key)) - }) - .collect(); - - let mut chain_rows: Vec<(String,)> = - query_as("SELECT key, key FROM group_session_chains WHERE session_id = ?") - .bind(session_row_id) - .fetch_all(&mut *connection) - .await?; - - let chains: Vec = chain_rows.drain(..).map(|r| r.0).collect(); - - let chains = if chains.is_empty() { - None - } else { - Some(chains) - }; - - let pickle = PickledInboundGroupSession { - pickle: InboundGroupSessionPickle::from(pickle), - sender_key, - signing_key: claimed_keys, - room_id: RoomId::try_from(room_id)?, - forwarding_chains: chains, - imported, - }; - - self.inbound_group_sessions - .add(InboundGroupSession::from_pickle( + let session = self + .load_inbound_session_data( + &mut connection, + session_row_id, pickle, - self.get_pickle_mode(), - )?); + sender_key, + room_id.to_owned(), + imported, + ) + .await?; + + sessions.push(session); } - Ok(()) + Ok(sessions) } async fn save_tracked_user(&self, user: &UserId, dirty: bool) -> Result<()> { @@ -647,119 +821,176 @@ impl SqliteStore { Ok(()) } - async fn load_devices(&self) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; + async fn load_device_data( + &self, + connection: &mut SqliteConnection, + device_row_id: i64, + user_id: &UserId, + device_id: DeviceIdBox, + trust_state: LocalTrust, + display_name: Option, + ) -> Result { + let algorithm_rows: Vec<(String,)> = + query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; - let rows: Vec<(i64, String, String, Option, i64)> = query_as( - "SELECT id, user_id, device_id, display_name, trust_state - FROM devices WHERE account_id = ?", + let algorithms = algorithm_rows + .iter() + .map(|row| { + let algorithm: &str = &row.0; + EventEncryptionAlgorithm::from(algorithm) + }) + .collect::>(); + + let key_rows: Vec<(String, String)> = + query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") + .bind(device_row_id) + .fetch_all(&mut *connection) + .await?; + + let keys: BTreeMap = key_rows + .into_iter() + .filter_map(|row| { + let algorithm = row.0.parse::().ok()?; + let key = row.1; + + Some((DeviceKeyId::from_parts(algorithm, &device_id), key)) + }) + .collect(); + + let signature_rows: Vec<(String, String, String)> = query_as( + "SELECT user_id, key_algorithm, signature + FROM device_signatures WHERE device_id = ?", ) - .bind(account_id) + .bind(device_row_id) .fetch_all(&mut *connection) .await?; - for row in rows { - let device_row_id = row.0; - let user_id: &str = &row.1; - let user_id = if let Ok(u) = UserId::try_from(user_id) { + let mut signatures: BTreeMap> = BTreeMap::new(); + + for row in signature_rows { + let user_id = if let Ok(u) = UserId::try_from(&*row.0) { u } else { continue; }; - let device_id = &row.2.to_string(); - let display_name = &row.3; - let trust_state = LocalTrust::from(row.4); + let key_algorithm = if let Ok(k) = row.1.parse::() { + k + } else { + continue; + }; - let algorithm_rows: Vec<(String,)> = - query_as("SELECT algorithm FROM algorithms WHERE device_id = ?") - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; + let signature = row.2; - let algorithms = algorithm_rows - .iter() - .map(|row| { - let algorithm: &str = &row.0; - EventEncryptionAlgorithm::from(algorithm) - }) - .collect::>(); - - let key_rows: Vec<(String, String)> = - query_as("SELECT algorithm, key FROM device_keys WHERE device_id = ?") - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let keys: BTreeMap = key_rows - .into_iter() - .filter_map(|row| { - let algorithm = row.0.parse::().ok()?; - let key = row.1; - - Some(( - DeviceKeyId::from_parts(algorithm, device_id.as_str().into()), - key, - )) - }) - .collect(); - - let signature_rows: Vec<(String, String, String)> = query_as( - "SELECT user_id, key_algorithm, signature - FROM device_signatures WHERE device_id = ?", - ) - .bind(device_row_id) - .fetch_all(&mut *connection) - .await?; - - let mut signatures: BTreeMap> = BTreeMap::new(); - - for row in signature_rows { - let user_id = if let Ok(u) = UserId::try_from(&*row.0) { - u - } else { - continue; - }; - - let key_algorithm = if let Ok(k) = row.1.parse::() { - k - } else { - continue; - }; - - let signature = row.2; - - signatures - .entry(user_id) - .or_insert_with(BTreeMap::new) - .insert( - DeviceKeyId::from_parts(key_algorithm, device_id.as_str().into()), - signature.to_owned(), - ); - } - - let device = ReadOnlyDevice::new( - user_id, - device_id.as_str().into(), - display_name.clone(), - trust_state, - algorithms, - keys, - signatures, - ); - - self.devices.add(device); + signatures + .entry(user_id) + .or_insert_with(BTreeMap::new) + .insert( + DeviceKeyId::from_parts(key_algorithm, device_id.as_str().into()), + signature.to_owned(), + ); } - Ok(()) + Ok(ReadOnlyDevice::new( + user_id.to_owned(), + device_id, + display_name.clone(), + trust_state, + algorithms, + keys, + signatures, + )) } - async fn save_device_helper(&self, device: ReadOnlyDevice) -> Result<()> { + async fn get_single_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; + let row: Option<(i64, Option, i64)> = query_as( + "SELECT id, display_name, trust_state + FROM devices WHERE account_id = ? and user_id = ? and device_id = ?", + ) + .bind(account_id) + .bind(user_id.as_str()) + .bind(device_id.as_str()) + .fetch_optional(&mut *connection) + .await?; + + let row = if let Some(r) = row { + r + } else { + return Ok(None); + }; + + let device_row_id = row.0; + let display_name = row.1; + let trust_state = LocalTrust::from(row.2); + let device = self + .load_device_data( + &mut connection, + device_row_id, + user_id, + device_id.into(), + trust_state, + display_name, + ) + .await?; + + Ok(Some(device)) + } + + async fn load_devices(&self, user_id: &UserId) -> Result> { + let mut devices = HashMap::new(); + + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let mut rows: Vec<(i64, String, Option, i64)> = query_as( + "SELECT id, device_id, display_name, trust_state + FROM devices WHERE account_id = ? and user_id = ?", + ) + .bind(account_id) + .bind(user_id.as_str()) + .fetch_all(&mut *connection) + .await?; + + for row in rows.drain(..) { + let device_row_id = row.0; + let device_id: DeviceIdBox = row.1.into(); + let display_name = row.2; + let trust_state = LocalTrust::from(row.3); + + let device = self + .load_device_data( + &mut connection, + device_row_id, + user_id, + device_id.clone(), + trust_state, + display_name, + ) + .await?; + + devices.insert(device_id, device); + } + + Ok(devices) + } + + async fn save_device_helper( + &self, + connection: &mut SqliteConnection, + device: ReadOnlyDevice, + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + query( "INSERT INTO devices ( account_id, user_id, device_id, @@ -837,12 +1068,11 @@ impl SqliteStore { } fn get_pickle_mode(&self) -> PicklingMode { - match &*self.pickle_passphrase { - Some(p) => PicklingMode::Encrypted { - key: p.as_bytes().to_vec(), - }, - None => PicklingMode::Unencrypted, - } + self.pickle_key.pickle_mode() + } + + fn get_pickle_key(&self) -> &[u8] { + self.pickle_key.key() } async fn save_inbound_group_session_helper( @@ -1108,10 +1338,140 @@ impl SqliteStore { Ok(()) } - async fn save_user_helper(&self, user: &UserIdentities) -> Result<()> { + #[cfg(test)] + async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { + let mut connection = self.connection.lock().await; + let mut transaction = connection.begin().await?; + + self.save_sessions_helper(&mut transaction, sessions) + .await?; + transaction.commit().await?; + + Ok(()) + } + + async fn save_sessions_helper( + &self, + connection: &mut SqliteConnection, + sessions: &[Session], + ) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + for session in sessions { + self.lazy_load_sessions(connection, &session.sender_key) + .await?; + } + + for session in sessions { + self.sessions.add(session.clone()).await; + + let pickle = session.pickle(self.get_pickle_mode()).await; + + let session_id = session.session_id(); + let creation_time = serde_json::to_string(&pickle.creation_time)?; + let last_use_time = serde_json::to_string(&pickle.last_use_time)?; + + query( + "REPLACE INTO sessions ( + session_id, account_id, creation_time, last_use_time, sender_key, pickle + ) VALUES (?, ?, ?, ?, ?, ?)", + ) + .bind(&session_id) + .bind(&account_id) + .bind(&*creation_time) + .bind(&*last_use_time) + .bind(&pickle.sender_key) + .bind(&pickle.pickle.as_str()) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + + async fn save_devices( + &self, + mut connection: &mut SqliteConnection, + devices: &[ReadOnlyDevice], + ) -> Result<()> { + for device in devices { + self.save_device_helper(&mut connection, device.clone()) + .await? + } + + Ok(()) + } + + async fn delete_devices( + &self, + connection: &mut SqliteConnection, + devices: &[ReadOnlyDevice], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for device in devices { + query( + "DELETE FROM devices + WHERE account_id = ?1 and user_id = ?2 and device_id = ?3 + ", + ) + .bind(account_id) + .bind(&device.user_id().to_string()) + .bind(device.device_id().as_str()) + .execute(&mut *connection) + .await?; + } + + Ok(()) + } + + #[cfg(test)] + async fn save_inbound_group_sessions_test( + &self, + sessions: &[InboundGroupSession], + ) -> Result<()> { let mut connection = self.connection.lock().await; + let mut transaction = connection.begin().await?; + + self.save_inbound_group_sessions(&mut transaction, sessions) + .await?; + + transaction.commit().await?; + Ok(()) + } + + async fn save_inbound_group_sessions( + &self, + connection: &mut SqliteConnection, + sessions: &[InboundGroupSession], + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + + for session in sessions { + self.save_inbound_group_session_helper(account_id, connection, session) + .await?; + } + + Ok(()) + } + + async fn save_user_identities( + &self, + mut connection: &mut SqliteConnection, + users: &[UserIdentities], + ) -> Result<()> { + for user in users { + self.save_user_helper(&mut connection, user).await?; + } + Ok(()) + } + + async fn save_user_helper( + &self, + mut connection: &mut SqliteConnection, + user: &UserIdentities, + ) -> Result<()> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; query("REPLACE INTO users (account_id, user_id) VALUES (?1, ?2)") .bind(account_id) @@ -1202,16 +1562,14 @@ impl CryptoStore for SqliteStore { drop(connection); - self.load_inbound_group_sessions().await?; - self.load_devices().await?; self.load_tracked_users().await?; Ok(result) } async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { - let pickle = account.pickle(self.get_pickle_mode()).await; let mut connection = self.connection.lock().await; + let pickle = account.pickle(self.get_pickle_mode()).await; query( "INSERT INTO accounts ( @@ -1246,57 +1604,91 @@ impl CryptoStore for SqliteStore { Ok(()) } - async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { + async fn save_identity(&self, identity: PrivateCrossSigningIdentity) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - // TODO turn this into a transaction - for session in sessions { - self.lazy_load_sessions(&session.sender_key).await?; - self.sessions.add(session.clone()).await; + let pickle = identity.pickle(self.get_pickle_key()).await?; - let pickle = session.pickle(self.get_pickle_mode()).await; + let mut connection = self.connection.lock().await; - let session_id = session.session_id(); - let creation_time = serde_json::to_string(&pickle.creation_time)?; - let last_use_time = serde_json::to_string(&pickle.last_use_time)?; + query( + "INSERT INTO private_identities ( + account_id, user_id, pickle, shared + ) VALUES (?1, ?2, ?3, ?4) + ON CONFLICT(account_id, user_id) DO UPDATE SET + pickle = excluded.pickle, + shared = excluded.shared + ", + ) + .bind(account_id) + .bind(pickle.user_id.as_str()) + .bind(pickle.pickle) + .bind(pickle.shared) + .execute(&mut *connection) + .await?; - let mut connection = self.connection.lock().await; + Ok(()) + } - query( - "REPLACE INTO sessions ( - session_id, account_id, creation_time, last_use_time, sender_key, pickle - ) VALUES (?, ?, ?, ?, ?, ?)", - ) - .bind(&session_id) - .bind(&account_id) - .bind(&*creation_time) - .bind(&*last_use_time) - .bind(&pickle.sender_key) - .bind(&pickle.pickle.as_str()) - .execute(&mut *connection) - .await?; + async fn load_identity(&self) -> Result> { + let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; + let mut connection = self.connection.lock().await; + + let row: Option<(String, bool)> = query_as( + "SELECT pickle, shared FROM private_identities + WHERE account_id = ?", + ) + .bind(account_id) + .fetch_optional(&mut *connection) + .await?; + + if let Some(row) = row { + let pickle = PickledCrossSigningIdentity { + user_id: (&*self.user_id).clone(), + pickle: row.0, + shared: row.1, + }; + + // TODO remove this unwrap + let identity = PrivateCrossSigningIdentity::from_pickle(pickle, self.get_pickle_key()) + .await + .unwrap(); + + Ok(Some(identity)) + } else { + Ok(None) } + } + + async fn save_changes(&self, changes: Changes) -> Result<()> { + let mut connection = self.connection.lock().await; + let mut transaction = connection.begin().await?; + + self.save_sessions_helper(&mut transaction, &changes.sessions) + .await?; + self.save_inbound_group_sessions(&mut transaction, &changes.inbound_group_sessions) + .await?; + + self.save_devices(&mut transaction, &changes.devices.new) + .await?; + self.save_devices(&mut transaction, &changes.devices.changed) + .await?; + self.delete_devices(&mut transaction, &changes.devices.deleted) + .await?; + + self.save_user_identities(&mut transaction, &changes.identities.new) + .await?; + self.save_user_identities(&mut transaction, &changes.identities.changed) + .await?; + + transaction.commit().await?; Ok(()) } async fn get_sessions(&self, sender_key: &str) -> Result>>>> { - Ok(self.get_sessions_for(sender_key).await?) - } - - async fn save_inbound_group_sessions(&self, sessions: &[InboundGroupSession]) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; - - // FIXME use a transaction here once sqlx gets better support for them. - - for session in sessions { - self.save_inbound_group_session_helper(account_id, &mut connection, session) - .await?; - self.inbound_group_sessions.add(session.clone()); - } - - Ok(()) + Ok(self.get_sessions_for(&mut connection, sender_key).await?) } async fn get_inbound_group_session( @@ -1306,12 +1698,12 @@ impl CryptoStore for SqliteStore { session_id: &str, ) -> Result> { Ok(self - .inbound_group_sessions - .get(room_id, sender_key, session_id)) + .load_inbound_group_session_helper(room_id, sender_key, session_id) + .await?) } async fn get_inbound_group_sessions(&self) -> Result> { - Ok(self.inbound_group_sessions.get_all()) + Ok(self.load_inbound_group_sessions().await?) } fn is_user_tracked(&self, user_id: &UserId) -> bool { @@ -1341,58 +1733,25 @@ impl CryptoStore for SqliteStore { Ok(already_added) } - async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { - // TODO turn this into a bulk transaction. - for device in devices { - self.devices.add(device.clone()); - self.save_device_helper(device.clone()).await? - } - - Ok(()) - } - - async fn delete_device(&self, device: ReadOnlyDevice) -> Result<()> { - let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; - let mut connection = self.connection.lock().await; - - query( - "DELETE FROM devices - WHERE account_id = ?1 and user_id = ?2 and device_id = ?3 - ", - ) - .bind(account_id) - .bind(&device.user_id().to_string()) - .bind(device.device_id().as_str()) - .execute(&mut *connection) - .await?; - - Ok(()) - } - async fn get_device( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result> { - Ok(self.devices.get(user_id, device_id)) + self.get_single_device(user_id, device_id).await } - async fn get_user_devices(&self, user_id: &UserId) -> Result { - Ok(self.devices.user_devices(user_id)) + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result> { + Ok(self.load_devices(user_id).await?) } async fn get_user_identity(&self, user_id: &UserId) -> Result> { self.load_user(user_id).await } - async fn save_user_identities(&self, users: &[UserIdentities]) -> Result<()> { - for user in users { - self.save_user_helper(user).await?; - } - - Ok(()) - } - async fn save_value(&self, key: String, value: String) -> Result<()> { let account_id = self.account_id().ok_or(CryptoStoreError::AccountUnset)?; let mut connection = self.connection.lock().await; @@ -1457,7 +1816,11 @@ mod test { device::test::get_device, user::test::{get_other_identity, get_own_identity}, }, - olm::{GroupSessionKey, InboundGroupSession, ReadOnlyAccount, Session}, + olm::{ + GroupSessionKey, InboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount, + Session, + }, + store::{Changes, DeviceChanges, IdentityChanges}, }; use matrix_sdk_common::{ api::r0::keys::SignedKey, @@ -1549,7 +1912,7 @@ mod test { (alice, session) } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn create_store() { let tmpdir = tempdir().unwrap(); let tmpdir_path = tmpdir.path().to_str().unwrap(); @@ -1558,7 +1921,7 @@ mod test { .expect("Can't create store"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_account() { let (store, _dir) = get_store(None).await; assert!(store.load_account().await.unwrap().is_none()); @@ -1570,7 +1933,7 @@ mod test { .expect("Can't save account"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_account() { let (store, _dir) = get_store(None).await; let account = get_account(); @@ -1586,7 +1949,7 @@ mod test { assert_eq!(account, loaded_account); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_account_with_passphrase() { let (store, _dir) = get_store(Some("secret_passphrase")).await; let account = get_account(); @@ -1602,7 +1965,7 @@ mod test { assert_eq!(account, loaded_account); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_and_share_account() { let (store, _dir) = get_store(None).await; let account = get_account(); @@ -1630,7 +1993,7 @@ mod test { ); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_session() { let (store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; @@ -1645,7 +2008,7 @@ mod test { store.save_sessions(&[session]).await.unwrap(); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_sessions() { let (store, _dir) = get_store(None).await; let (account, session) = get_account_and_session().await; @@ -1664,7 +2027,7 @@ mod test { assert_eq!(&session, loaded_session); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn add_and_save_session() { let (store, dir) = get_store(None).await; let (account, session) = get_account_and_session().await; @@ -1699,7 +2062,7 @@ mod test { assert_eq!(session_id, session.session_id()); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn save_inbound_group_session() { let (account, store, _dir) = get_loaded_store().await; @@ -1714,12 +2077,12 @@ mod test { .expect("Can't create session"); store - .save_inbound_group_sessions(&[session]) + .save_inbound_group_sessions_test(&[session]) .await .expect("Can't save group session"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn load_inbound_group_session() { let (account, store, dir) = get_loaded_store().await; @@ -1740,7 +2103,7 @@ mod test { let session = InboundGroupSession::from_export(export).unwrap(); store - .save_inbound_group_sessions(&[session.clone()]) + .save_inbound_group_sessions_test(&[session.clone()]) .await .expect("Can't save group session"); @@ -1749,7 +2112,6 @@ mod test { .expect("Can't create store"); store.load_account().await.unwrap(); - store.load_inbound_group_sessions().await.unwrap(); let loaded_session = store .get_inbound_group_session(&session.room_id, &session.sender_key, session.session_id()) @@ -1761,7 +2123,7 @@ mod test { assert!(!export.forwarding_curve25519_key_chain.is_empty()) } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn test_tracked_users() { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); @@ -1808,12 +2170,20 @@ mod test { assert!(!store.users_for_key_query().contains(device.user_id())); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn device_saving() { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); - store.save_devices(&[device.clone()]).await.unwrap(); + let changes = Changes { + devices: DeviceChanges { + changed: vec![device.clone()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); drop(store); @@ -1838,17 +2208,34 @@ mod test { assert_eq!(device.keys(), loaded_device.keys()); let user_devices = store.get_user_devices(device.user_id()).await.unwrap(); - assert_eq!(user_devices.keys().next().unwrap(), device.device_id()); - assert_eq!(user_devices.devices().next().unwrap(), &device); + assert_eq!(&**user_devices.keys().next().unwrap(), device.device_id()); + assert_eq!(user_devices.values().next().unwrap(), &device); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn device_deleting() { let (_account, store, dir) = get_loaded_store().await; let device = get_device(); - store.save_devices(&[device.clone()]).await.unwrap(); - store.delete_device(device.clone()).await.unwrap(); + let changes = Changes { + devices: DeviceChanges { + changed: vec![device.clone()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); + + let changes = Changes { + devices: DeviceChanges { + deleted: vec![device.clone()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let store = SqliteStore::open(&alice_id(), &alice_device_id(), dir.path()) .await @@ -1864,7 +2251,7 @@ mod test { assert!(loaded_device.is_none()); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn user_saving() { let dir = tempdir().unwrap(); let tmpdir_path = dir.path().to_str().unwrap(); @@ -1885,8 +2272,16 @@ mod test { let own_identity = get_own_identity(); + let changes = Changes { + identities: IdentityChanges { + changed: vec![own_identity.clone().into()], + ..Default::default() + }, + ..Default::default() + }; + store - .save_user_identities(&[own_identity.clone().into()]) + .save_changes(changes) .await .expect("Can't save identity"); @@ -1913,10 +2308,15 @@ mod test { let other_identity = get_other_identity(); - store - .save_user_identities(&[other_identity.clone().into()]) - .await - .unwrap(); + let changes = Changes { + identities: IdentityChanges { + changed: vec![other_identity.clone().into()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let loaded_user = store .load_user(other_identity.user_id()) @@ -1933,15 +2333,31 @@ mod test { own_identity.mark_as_verified(); - store - .save_user_identities(&[own_identity.into()]) - .await - .unwrap(); + let changes = Changes { + identities: IdentityChanges { + changed: vec![own_identity.into()], + ..Default::default() + }, + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); let loaded_user = store.load_user(&user_id).await.unwrap().unwrap(); assert!(loaded_user.own().unwrap().is_verified()) } - #[tokio::test] + #[tokio::test(threaded_scheduler)] + async fn private_identity_saving() { + let (_, store, _dir) = get_loaded_store().await; + assert!(store.load_identity().await.unwrap().is_none()); + let identity = PrivateCrossSigningIdentity::new((&*store.user_id).clone()).await; + + store.save_identity(identity.clone()).await.unwrap(); + let loaded_identity = store.load_identity().await.unwrap().unwrap(); + assert_eq!(identity.user_id(), loaded_identity.user_id()); + } + + #[tokio::test(threaded_scheduler)] async fn key_value_saving() { let (_, store, _dir) = get_loaded_store().await; let key = "test_key".to_string(); diff --git a/matrix_sdk_crypto/src/verification/machine.rs b/matrix_sdk_crypto/src/verification/machine.rs index ee12d562f..ba098ddcb 100644 --- a/matrix_sdk_crypto/src/verification/machine.rs +++ b/matrix_sdk_crypto/src/verification/machine.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use dashmap::DashMap; +use matrix_sdk_common::locks::Mutex; use tracing::{trace, warn}; use matrix_sdk_common::{ @@ -26,6 +27,7 @@ use matrix_sdk_common::{ use super::sas::{content_to_request, Sas}; use crate::{ + olm::PrivateCrossSigningIdentity, requests::{OutgoingRequest, ToDeviceRequest}, store::{CryptoStore, CryptoStoreError}, ReadOnlyAccount, ReadOnlyDevice, @@ -34,15 +36,21 @@ use crate::{ #[derive(Clone, Debug)] pub struct VerificationMachine { account: ReadOnlyAccount, + user_identity: Arc>, pub(crate) store: Arc>, verifications: Arc>, outgoing_to_device_messages: Arc>, } impl VerificationMachine { - pub(crate) fn new(account: ReadOnlyAccount, store: Arc>) -> Self { + pub(crate) fn new( + account: ReadOnlyAccount, + identity: Arc>, + store: Arc>, + ) -> Self { Self { account, + user_identity: identity, store, verifications: Arc::new(DashMap::new()), outgoing_to_device_messages: Arc::new(DashMap::new()), @@ -194,18 +202,14 @@ impl VerificationMachine { self.receive_event_helper(&s, event); if s.is_done() { - if !s.mark_device_as_verified().await? { - if let Some(r) = s.cancel() { - self.outgoing_to_device_messages.insert( - r.txn_id, - OutgoingRequest { - request_id: r.txn_id, - request: Arc::new(r.into()), - }, - ); - } - } else { - s.mark_identity_as_verified().await?; + if let Some(r) = s.mark_as_done().await? { + self.outgoing_to_device_messages.insert( + r.txn_id, + OutgoingRequest { + request_id: r.txn_id, + request: Arc::new(r.into()), + }, + ); } } }; @@ -228,10 +232,12 @@ mod test { use matrix_sdk_common::{ events::AnyToDeviceEventContent, identifiers::{DeviceId, UserId}, + locks::Mutex, }; use super::{Sas, VerificationMachine}; use crate::{ + olm::PrivateCrossSigningIdentity, requests::OutgoingRequests, store::{CryptoStore, MemoryStore}, verification::test::{get_content_from_request, wrap_any_to_device_content}, @@ -258,18 +264,17 @@ mod test { let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); let bob = ReadOnlyAccount::new(&bob_id(), &bob_device_id()); let store = MemoryStore::new(); - let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); + let bob_store = MemoryStore::new(); let bob_device = ReadOnlyDevice::from_account(&bob).await; let alice_device = ReadOnlyDevice::from_account(&alice).await; - store.save_devices(&[bob_device]).await.unwrap(); - bob_store - .save_devices(&[alice_device.clone()]) - .await - .unwrap(); + store.save_devices(vec![bob_device]).await; + bob_store.save_devices(vec![alice_device.clone()]).await; - let machine = VerificationMachine::new(alice, Arc::new(Box::new(store))); + let bob_store: Arc> = Arc::new(Box::new(bob_store)); + let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); + let machine = VerificationMachine::new(alice, identity, Arc::new(Box::new(store))); let (bob_sas, start_content) = Sas::start(bob, alice_device, bob_store, None); machine .receive_event(&mut wrap_any_to_device_content( @@ -285,8 +290,9 @@ mod test { #[test] fn create() { let alice = ReadOnlyAccount::new(&alice_id(), &alice_device_id()); + let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let store = MemoryStore::new(); - let _ = VerificationMachine::new(alice, Arc::new(Box::new(store))); + let _ = VerificationMachine::new(alice, identity, Arc::new(Box::new(store))); } #[tokio::test] diff --git a/matrix_sdk_crypto/src/verification/sas/mod.rs b/matrix_sdk_crypto/src/verification/sas/mod.rs index 95d615419..b71070a7d 100644 --- a/matrix_sdk_crypto/src/verification/sas/mod.rs +++ b/matrix_sdk_crypto/src/verification/sas/mod.rs @@ -34,7 +34,7 @@ use matrix_sdk_common::{ use crate::{ identities::{LocalTrust, ReadOnlyDevice, UserIdentities}, - store::{CryptoStore, CryptoStoreError}, + store::{Changes, CryptoStore, CryptoStoreError, DeviceChanges}, ReadOnlyAccount, ToDeviceRequest, }; @@ -189,34 +189,64 @@ impl Sas { (content, guard.is_done()) }; - if done { - // TODO move the logic that marks and stores the device into the - // else branch and only after the identity was verified as well. We - // dont' want to verify one without the other. - if !self.mark_device_as_verified().await? { - return Ok(self.cancel()); - } else { - self.mark_identity_as_verified().await?; - } - } + let cancel = if done { + self.mark_as_done().await? + } else { + None + }; - Ok(content.map(|c| { - let content = AnyToDeviceEventContent::KeyVerificationMac(c); - self.content_to_request(content) - })) + if cancel.is_some() { + Ok(cancel) + } else { + Ok(content.map(|c| { + let content = AnyToDeviceEventContent::KeyVerificationMac(c); + self.content_to_request(content) + })) + } } - pub(crate) async fn mark_identity_as_verified(&self) -> Result { + pub(crate) async fn mark_as_done(&self) -> Result, CryptoStoreError> { + if let Some(device) = self.mark_device_as_verified().await? { + let identity = self.mark_identity_as_verified().await?; + + let mut changes = Changes { + devices: DeviceChanges { + changed: vec![device], + ..Default::default() + }, + ..Default::default() + }; + + if let Some(i) = identity { + changes.identities.changed.push(i); + } + + self.store.save_changes(changes).await?; + Ok(None) + } else { + Ok(self.cancel()) + } + } + + pub(crate) async fn mark_identity_as_verified( + &self, + ) -> Result, CryptoStoreError> { // If there wasn't an identity available during the verification flow // return early as there's nothing to do. if self.other_identity.is_none() { - return Ok(false); + return Ok(None); } + // TODO signal an error, e.g. when the identity got deleted so we don't + // verify/save the device either. let identity = self.store.get_user_identity(self.other_user_id()).await?; if let Some(identity) = identity { - if identity.master_key() == self.other_identity.as_ref().unwrap().master_key() { + if self + .other_identity + .as_ref() + .map_or(false, |i| i.master_key() == identity.master_key()) + { if self .verified_identities() .map_or(false, |i| i.contains(&identity)) @@ -228,13 +258,12 @@ impl Sas { if let UserIdentities::Own(i) = &identity { i.mark_as_verified(); - self.store.save_user_identities(&[identity]).await?; } // TODO if we have the private part of the user signing // key we should sign and upload a signature for this // identity. - Ok(true) + Ok(Some(identity)) } else { info!( "The interactive verification process didn't contain a \ @@ -243,7 +272,7 @@ impl Sas { self.verified_identities(), ); - Ok(false) + Ok(None) } } else { warn!( @@ -252,7 +281,7 @@ impl Sas { identity.user_id(), ); - Ok(false) + Ok(None) } } else { info!( @@ -260,11 +289,13 @@ impl Sas { verification was going on.", self.other_user_id(), ); - Ok(false) + Ok(None) } } - pub(crate) async fn mark_device_as_verified(&self) -> Result { + pub(crate) async fn mark_device_as_verified( + &self, + ) -> Result, CryptoStoreError> { let device = self .store .get_device(self.other_user_id(), self.other_device_id()) @@ -283,12 +314,11 @@ impl Sas { ); device.set_trust_state(LocalTrust::Verified); - self.store.save_devices(&[device]).await?; // TODO if this is a device from our own user and we have // the private part of the self signing key, we should sign // the device and upload the signature. - Ok(true) + Ok(Some(device)) } else { info!( "The interactive verification process didn't contain a \ @@ -297,7 +327,7 @@ impl Sas { device.device_id() ); - Ok(false) + Ok(None) } } else { warn!( @@ -306,7 +336,7 @@ impl Sas { device.user_id(), device.device_id() ); - Ok(false) + Ok(None) } } else { let device = self.other_device(); @@ -317,7 +347,7 @@ impl Sas { device.user_id(), device.device_id() ); - Ok(false) + Ok(None) } } @@ -777,12 +807,11 @@ mod test { let bob_device = ReadOnlyDevice::from_account(&bob).await; let alice_store: Arc> = Arc::new(Box::new(MemoryStore::new())); - let bob_store: Arc> = Arc::new(Box::new(MemoryStore::new())); + let bob_store = MemoryStore::new(); - bob_store - .save_devices(&[alice_device.clone()]) - .await - .unwrap(); + bob_store.save_devices(vec![alice_device.clone()]).await; + + let bob_store: Arc> = Arc::new(Box::new(bob_store)); let (alice, content) = Sas::start(alice, bob_device, alice_store, None); let event = wrap_to_device_event(alice.user_id(), content); diff --git a/matrix_sdk_test/Cargo.toml b/matrix_sdk_test/Cargo.toml index 108ad5991..b6a7bcb9c 100644 --- a/matrix_sdk_test/Cargo.toml +++ b/matrix_sdk_test/Cargo.toml @@ -11,7 +11,7 @@ repository = "https://github.com/matrix-org/matrix-rust-sdk" version = "0.1.0" [dependencies] -serde_json = "1.0.58" +serde_json = "1.0.59" http = "0.2.1" matrix-sdk-common = { version = "0.1.0", path = "../matrix_sdk_common" } matrix-sdk-test-macros = { version = "0.1.0", path = "../matrix_sdk_test_macros" }