diff --git a/benchmarks/benches/crypto_bench.rs b/benchmarks/benches/crypto_bench.rs index aa835ff20..737e2ae19 100644 --- a/benchmarks/benches/crypto_bench.rs +++ b/benchmarks/benches/crypto_bench.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; use criterion::*; use matrix_sdk_crypto::{EncryptionSettings, OlmMachine}; @@ -71,7 +71,7 @@ pub fn keys_query(c: &mut Criterion) { }); let dir = tempfile::tempdir().unwrap(); - let store = Box::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); + let store = Arc::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); let machine = runtime.block_on(OlmMachine::with_store(alice_id(), alice_device_id(), store)).unwrap(); @@ -119,7 +119,7 @@ pub fn keys_claiming(c: &mut Criterion) { || { let dir = tempfile::tempdir().unwrap(); let store = - Box::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); + Arc::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); let machine = runtime .block_on(OlmMachine::with_store(alice_id(), alice_device_id(), store)) @@ -181,7 +181,7 @@ pub fn room_key_sharing(c: &mut Criterion) { }) }); let dir = tempfile::tempdir().unwrap(); - let store = Box::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); + let store = Arc::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); let machine = runtime.block_on(OlmMachine::with_store(alice_id(), alice_device_id(), store)).unwrap(); @@ -236,7 +236,7 @@ pub fn devices_missing_sessions_collecting(c: &mut Criterion) { }); let dir = tempfile::tempdir().unwrap(); - let store = Box::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); + let store = Arc::new(SledCryptoStore::open_with_passphrase(dir.path(), None).unwrap()); let machine = runtime.block_on(OlmMachine::with_store(alice_id(), alice_device_id(), store)).unwrap(); diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index 8fdd04cf0..a606c996e 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -37,6 +37,7 @@ http = { version = "0.2.6", optional = true } lru = "0.7.5" matrix-sdk-common = { version = "0.5.0", path = "../matrix-sdk-common" } matrix-sdk-crypto = { version = "0.5.0", path = "../matrix-sdk-crypto", optional = true } +once_cell = "1.10.0" pbkdf2 = { version = "0.11.0", default-features = false, optional = true } rand = { version = "0.8.5", optional = true } serde = { version = "1.0.136", features = ["rc"] } diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index df920c5a0..2e857839d 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -24,8 +24,6 @@ use std::{ops::Deref, result::Result as StdResult}; #[cfg(feature = "experimental-timeline")] use matrix_sdk_common::deserialized_responses::TimelineSlice; -#[cfg(feature = "e2e-encryption")] -use matrix_sdk_common::locks::Mutex; use matrix_sdk_common::{ deserialized_responses::{ AmbiguityChanges, JoinedRoom, LeftRoom, MembersResponse, Rooms, SyncResponse, @@ -41,6 +39,8 @@ use matrix_sdk_crypto::{ OutgoingRequest, ToDeviceRequest, UserDevices, }; #[cfg(feature = "e2e-encryption")] +use once_cell::sync::OnceCell; +#[cfg(feature = "e2e-encryption")] use ruma::{ api::client::keys::claim_keys::v3::Request as KeysClaimRequest, events::{ @@ -86,71 +86,30 @@ pub type Token = String; /// accordingly updates its state. #[derive(Clone)] pub struct BaseClient { - /// The current client session containing our user id, device id and access - /// token. - session: Arc>>, /// The current sync token that should be used for the next sync call. pub(crate) sync_token: Arc>>, /// Database store: Store, + /// The store used for encryption #[cfg(feature = "e2e-encryption")] - olm: Arc>, + crypto_store: Arc, + /// The olm-machine that is created once the + /// [`Session`][crate::session::Session] is set via + /// [`BaseClient::restore_login`] + #[cfg(feature = "e2e-encryption")] + olm_machine: OnceCell, } #[cfg(not(tarpaulin_include))] impl fmt::Debug for BaseClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Client") - .field("session", &self.session) + .field("session", &self.session()) .field("sync_token", &self.sync_token) .finish() } } -#[cfg(feature = "e2e-encryption")] -enum CryptoHolder { - PreSetupStore(Option>), - Olm(Box), -} - -#[cfg(feature = "e2e-encryption")] -impl Default for CryptoHolder { - fn default() -> Self { - CryptoHolder::PreSetupStore(Some(Box::new(MemoryCryptoStore::default()))) - } -} - -#[cfg(feature = "e2e-encryption")] -impl CryptoHolder { - fn new(store: Box) -> Self { - CryptoHolder::PreSetupStore(Some(store)) - } - async fn convert_to_olm(&mut self, session: &Session) -> Result<()> { - if let CryptoHolder::PreSetupStore(store) = self { - *self = CryptoHolder::Olm(Box::new( - OlmMachine::with_store( - &session.user_id, - &session.device_id, - store.take().expect("We always exist"), - ) - .await - .map_err(OlmError::from)?, - )); - Ok(()) - } else { - Err(Error::BadCryptoStoreState) - } - } - - fn machine(&self) -> Option { - if let CryptoHolder::Olm(m) = self { - Some(*m.clone()) - } else { - None - } - } -} - impl BaseClient { /// Create a new default client. pub fn new() -> Self { @@ -166,21 +125,29 @@ impl BaseClient { pub fn with_store_config(config: StoreConfig) -> Self { let store = config.state_store.map(Store::new).unwrap_or_else(Store::open_memory_store); #[cfg(feature = "e2e-encryption")] - let holder = config.crypto_store.map(CryptoHolder::new).unwrap_or_default(); + let crypto_store = + config.crypto_store.unwrap_or_else(|| Box::new(MemoryCryptoStore::default())).into(); BaseClient { - session: store.session.clone(), sync_token: store.sync_token.clone(), store, #[cfg(feature = "e2e-encryption")] - olm: Mutex::new(holder).into(), + crypto_store, + #[cfg(feature = "e2e-encryption")] + olm_machine: Default::default(), } } - /// The current client session containing our user id, device id and access - /// token. - pub fn session(&self) -> &Arc>> { - &self.session + /// Get the user login session. + /// + /// If the client is currently logged in, this will return a + /// [`Session`][crate::session::Session] object which can later be given to + /// [`BaseClient::restore_login`]. + /// + /// Returns a session object if the client is logged in. Otherwise returns + /// `None`. + pub fn session(&self) -> Option<&Session> { + self.store.session() } /// Get a reference to the store. @@ -189,10 +156,8 @@ impl BaseClient { } /// Is the client logged in. - pub async fn logged_in(&self) -> bool { - // TODO turn this into a atomic bool so this method doesn't need to be - // async. - self.session.read().await.is_some() + pub fn logged_in(&self) -> bool { + self.store.session().is_some() } /// Receive a login response and update the session of the client. @@ -225,11 +190,18 @@ impl BaseClient { #[cfg(feature = "e2e-encryption")] { - let mut olm = self.olm.lock().await; - olm.convert_to_olm(&session).await?; - } + let olm_machine = OlmMachine::with_store( + &session.user_id, + &session.device_id, + self.crypto_store.clone(), + ) + .await + .map_err(OlmError::from)?; - *self.session.write().await = Some(session); + if self.olm_machine.set(olm_machine).is_err() { + return Err(Error::BadCryptoStoreState); + } + } Ok(()) } @@ -246,7 +218,7 @@ impl BaseClient { event: &AnySyncMessageLikeEvent, room_id: &RoomId, ) -> Result<()> { - if let Some(olm) = self.olm_machine().await { + if let Some(olm) = self.olm_machine() { olm.receive_unencrypted_verification_event( &event.clone().into_full_event(room_id.to_owned()), ) @@ -338,7 +310,7 @@ impl BaseClient { AnySyncMessageLikeEvent::RoomEncrypted( SyncMessageLikeEvent::Original(encrypted), ) => { - if let Some(olm) = self.olm_machine().await { + if let Some(olm) = self.olm_machine() { if let Ok(decrypted) = olm.decrypt_room_event(encrypted, room_id).await { @@ -599,7 +571,7 @@ impl BaseClient { #[cfg(feature = "e2e-encryption")] let to_device = { - if let Some(o) = self.olm_machine().await { + if let Some(o) = self.olm_machine() { // Let the crypto machine handle the sync response, this // decrypts to-device events, but leaves room events alone. // This makes sure that we have the decryption keys for the room @@ -672,7 +644,7 @@ impl BaseClient { #[cfg(feature = "e2e-encryption")] if room_info.is_encrypted() { - if let Some(o) = self.olm_machine().await { + if let Some(o) = self.olm_machine() { if !room.is_encrypted() { // The room turned on encryption in this sync, we need // to also get all the existing users and mark them for @@ -917,7 +889,7 @@ impl BaseClient { #[cfg(feature = "e2e-encryption")] if room_info.is_encrypted() { - if let Some(o) = self.olm_machine().await { + if let Some(o) = self.olm_machine() { o.update_tracked_users(user_ids.iter().map(Deref::deref)).await } } @@ -983,7 +955,7 @@ impl BaseClient { /// [`mark_request_as_sent`]: #method.mark_request_as_sent #[cfg(feature = "e2e-encryption")] pub async fn outgoing_requests(&self) -> Result, CryptoStoreError> { - match self.olm_machine().await { + match self.olm_machine() { Some(o) => o.outgoing_requests().await, None => Ok(vec![]), } @@ -1004,7 +976,7 @@ impl BaseClient { request_id: &TransactionId, response: impl Into>, ) -> Result<()> { - match self.olm_machine().await { + match self.olm_machine() { Some(o) => Ok(o.mark_request_as_sent(request_id, response).await?), None => Ok(()), } @@ -1018,7 +990,7 @@ impl BaseClient { &self, users: impl Iterator, ) -> Result> { - match self.olm_machine().await { + match self.olm_machine() { Some(o) => Ok(o.get_missing_sessions(users).await?), None => Ok(None), } @@ -1027,7 +999,7 @@ impl BaseClient { /// Get a to-device request that will share a group session for a room. #[cfg(feature = "e2e-encryption")] pub async fn share_group_session(&self, room_id: &RoomId) -> Result>> { - match self.olm_machine().await { + match self.olm_machine() { Some(o) => { let (history_visibility, settings) = self .get_room(room_id) @@ -1070,7 +1042,7 @@ impl BaseClient { room_id: &RoomId, content: impl MessageLikeEventContent, ) -> Result { - match self.olm_machine().await { + match self.olm_machine() { Some(o) => Ok(o.encrypt_room_event(room_id, content).await?), None => panic!("Olm machine wasn't started"), } @@ -1086,7 +1058,7 @@ impl BaseClient { &self, room_id: &RoomId, ) -> Result { - match self.olm_machine().await { + match self.olm_machine() { Some(o) => o.invalidate_group_session(room_id).await, None => Ok(false), } @@ -1126,25 +1098,13 @@ impl BaseClient { user_id: &UserId, device_id: &DeviceId, ) -> Result, CryptoStoreError> { - if let Some(olm) = self.olm_machine().await { + if let Some(olm) = self.olm_machine() { olm.get_device(user_id, device_id).await } else { Ok(None) } } - /// Get the user login session. - /// - /// If the client is currently logged in, this will return a - /// `matrix_sdk::Session` object which can later be given to - /// `restore_login`. - /// - /// Returns a session object if the client is logged in. Otherwise returns - /// `None`. - pub async fn get_session(&self) -> Option { - self.session.read().await.clone() - } - /// Get a map holding all the devices of an user. /// /// This will always return an empty map if the client hasn't been logged @@ -1181,7 +1141,7 @@ impl BaseClient { &self, user_id: &UserId, ) -> Result { - if let Some(olm) = self.olm_machine().await { + if let Some(olm) = self.olm_machine() { Ok(olm.get_user_devices(user_id).await?) } else { // TODO remove this panic. @@ -1191,9 +1151,8 @@ impl BaseClient { /// Get the olm machine. #[cfg(feature = "e2e-encryption")] - pub async fn olm_machine(&self) -> Option { - let olm = self.olm.lock().await; - olm.machine() + pub fn olm_machine(&self) -> Option<&OlmMachine> { + self.olm_machine.get() } /// Get the push rules. @@ -1216,7 +1175,7 @@ impl BaseClient { .transpose()? { Ok(event.content.global) - } else if let Some(session) = self.get_session().await { + } else if let Some(session) = self.session() { Ok(Ruleset::server_default(&session.user_id)) } else { Ok(Ruleset::new()) diff --git a/crates/matrix-sdk-base/src/lib.rs b/crates/matrix-sdk-base/src/lib.rs index f32253e60..3924d80b2 100644 --- a/crates/matrix-sdk-base/src/lib.rs +++ b/crates/matrix-sdk-base/src/lib.rs @@ -41,6 +41,7 @@ pub use client::BaseClient; pub use http; #[cfg(feature = "e2e-encryption")] pub use matrix_sdk_crypto as crypto; +pub use once_cell; pub use rooms::{DisplayName, Room, RoomInfo, RoomMember, RoomType}; pub use store::{StateChanges, StateStore, Store, StoreError}; pub use utils::{ diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index 3d4c99cdd..b7a581cf2 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -28,6 +28,8 @@ use std::{ sync::Arc, }; +use once_cell::sync::OnceCell; + #[cfg(any(test, feature = "testing"))] #[macro_use] pub mod integration_tests; @@ -380,7 +382,7 @@ pub trait StateStore: AsyncTraitDeps { #[derive(Debug, Clone)] pub struct Store { inner: Arc, - pub(crate) session: Arc>>, + pub(crate) session: Arc>, pub(crate) sync_token: Arc>>, rooms: Arc>, stripped_rooms: Arc>, @@ -423,11 +425,17 @@ impl Store { let token = self.get_sync_token().await?; *self.sync_token.write().await = token; - *self.session.write().await = Some(session); + self.session.set(session).expect("A session was already set"); Ok(()) } + /// The current [`Session`] containing our user id, device id and access + /// token. + pub fn session(&self) -> Option<&Session> { + self.session.get() + } + /// Get all the rooms this store knows about. pub fn get_rooms(&self) -> Vec { self.rooms.iter().filter_map(|r| self.get_room(r.key())).collect() @@ -458,8 +466,7 @@ impl Store { /// Lookup the stripped Room for the given RoomId, or create one, if it /// didn't exist yet in the store pub async fn get_or_create_stripped_room(&self, room_id: &RoomId) -> Room { - let session = self.session.read().await; - let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id; + let user_id = &self.session().expect("Creating room while not being logged in").user_id; self.stripped_rooms .entry(room_id.to_owned()) @@ -474,8 +481,7 @@ impl Store { return self.get_or_create_stripped_room(room_id).await; } - let session = self.session.read().await; - let user_id = &session.as_ref().expect("Creating room while not being logged in").user_id; + let user_id = &self.session().expect("Creating room while not being logged in").user_id; self.rooms .entry(room_id.to_owned()) diff --git a/crates/matrix-sdk-crypto-ffi/Cargo.toml b/crates/matrix-sdk-crypto-ffi/Cargo.toml index 6e7952137..b18573963 100644 --- a/crates/matrix-sdk-crypto-ffi/Cargo.toml +++ b/crates/matrix-sdk-crypto-ffi/Cargo.toml @@ -10,8 +10,8 @@ license = "Apache-2.0" publish = false [lib] -crate-type = ["cdylib", "lib"] -name = "matrix_crypto" +crate-type = ["cdylib", "staticlib"] +name = "matrix_crypto_ffi" [dependencies] anyhow = "1.0.57" @@ -55,7 +55,8 @@ default_features = false features = ["rt-multi-thread"] [dependencies.vodozemac] -version = "0.2.0" +git = "https://github.com/matrix-org/vodozemac/" +rev = "d0e744287a14319c2a9148fef3747548c740fc36" [build-dependencies] uniffi_build = { version = "0.17.0", features = ["builtin-bindgen"] } diff --git a/crates/matrix-sdk-crypto-ffi/src/lib.rs b/crates/matrix-sdk-crypto-ffi/src/lib.rs index 7b3d166a4..6ca3080a6 100644 --- a/crates/matrix-sdk-crypto-ffi/src/lib.rs +++ b/crates/matrix-sdk-crypto-ffi/src/lib.rs @@ -48,7 +48,7 @@ pub struct MigrationData { /// The list of Megolm inbound group sessions. inbound_group_sessions: Vec, /// The Olm pickle key that was used to pickle all the Olm objects. - pickle_key: String, + pickle_key: Vec, /// The backup version that is currently active. backup_version: Option, // The backup recovery key, as a base58 encoded string. @@ -521,7 +521,9 @@ mod test { "backed_up":true } ], - "pickle_key":"\u{0011}$xJ_N8$>{\u{0005}iJoF03eBVt\u{000e}rUU\\,GYc7J", + "pickle_key": [17, 36, 120, 74, 95, 78, 56, 36, 62, 123, 5, 105, 74, + 111, 70, 48, 51, 101, 66, 86, 116, 14, 114, 85, 85, + 92, 44, 71, 89, 99, 55, 74], "backup_version":"3", "backup_recovery_key":"EsTHScmRV5oT1WBhe2mj2Gn3odeYantZ4NEk7L51p6L8hrmB", "cross_signing":{ diff --git a/crates/matrix-sdk-crypto-ffi/src/machine.rs b/crates/matrix-sdk-crypto-ffi/src/machine.rs index fbd01ee4f..d59a82e95 100644 --- a/crates/matrix-sdk-crypto-ffi/src/machine.rs +++ b/crates/matrix-sdk-crypto-ffi/src/machine.rs @@ -3,6 +3,7 @@ use std::{ convert::TryInto, io::Cursor, ops::Deref, + sync::Arc, }; use base64::{decode_config, encode, STANDARD_NO_PAD}; @@ -32,7 +33,7 @@ use ruma::{ }, events::{ key::verification::VerificationMethod, room::encrypted::OriginalSyncRoomEncryptedEvent, - AnyMessageLikeEventContent, AnySyncMessageLikeEvent, EventContent, + AnySyncMessageLikeEvent, }, DeviceKeyAlgorithm, EventId, OwnedTransactionId, OwnedUserId, RoomId, UserId, }; @@ -92,7 +93,7 @@ impl OlmMachine { let device_id = device_id.into(); let runtime = Runtime::new().expect("Couldn't create a tokio runtime"); - let store = Box::new( + let store = Arc::new( matrix_sdk_sled::CryptoStore::open_with_passphrase(path, passphrase.as_deref()) .map_err(|e| { match e { @@ -518,12 +519,11 @@ impl OlmMachine { content: &str, ) -> Result { let room_id = RoomId::parse(room_id)?; - let content: Box = serde_json::from_str(content)?; + let content: Value = serde_json::from_str(content)?; - let content = AnyMessageLikeEventContent::from_parts(event_type, &content)?; let encrypted_content = self .runtime - .block_on(self.inner.encrypt_room_event(&room_id, content)) + .block_on(self.inner.encrypt_room_event_raw(&room_id, content, event_type)) .expect("Encrypting an event produced an error"); Ok(serde_json::to_string(&encrypted_content)?) diff --git a/crates/matrix-sdk-crypto-ffi/src/olm.udl b/crates/matrix-sdk-crypto-ffi/src/olm.udl index b7a1ba5f3..d9421f03e 100644 --- a/crates/matrix-sdk-crypto-ffi/src/olm.udl +++ b/crates/matrix-sdk-crypto-ffi/src/olm.udl @@ -451,7 +451,7 @@ dictionary MigrationData { sequence inbound_group_sessions; string? backup_version; string? backup_recovery_key; - string pickle_key; + sequence pickle_key; CrossSigningKeyExport cross_signing; sequence tracked_users; }; diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index a9e0eaa48..11f65b333 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -47,8 +47,23 @@ sha2 = "0.10.2" thiserror = "1.0.30" tracing = "0.1.34" zeroize = { version = "1.3.0", features = ["zeroize_derive"] } -ruma = { version = "0.6.2", features = ["client-api-c", "rand", "unstable-msc2676", "unstable-msc2677"] } -vodozemac = { git = "https://github.com/matrix-org/vodozemac", rev = "e09c93f2c8df9770793abeec57ed984d5e1f3834" } + +[target.'cfg(target_arch = "wasm32")'.dependencies.ruma] +version = "0.6.1" +features = ["client-api-c", "js", "rand", "signatures", "unstable-msc2676", "unstable-msc2677"] + +[target.'cfg(target_arch = "wasm32")'.dependencies.vodozemac] +git = "https://github.com/matrix-org/vodozemac/" +rev = "d0e744287a14319c2a9148fef3747548c740fc36" +features = ["js"] + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.ruma] +version = "0.6.1" +features = ["client-api-c", "rand", "signatures", "unstable-msc2676", "unstable-msc2677"] + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies.vodozemac] +git = "https://github.com/matrix-org/vodozemac/" +rev = "d0e744287a14319c2a9148fef3747548c740fc36" [dev-dependencies] futures = { version = "0.3.21", default-features = false, features = ["executor"] } diff --git a/crates/matrix-sdk-crypto/src/error.rs b/crates/matrix-sdk-crypto/src/error.rs index 6cf1037eb..f6069a727 100644 --- a/crates/matrix-sdk-crypto/src/error.rs +++ b/crates/matrix-sdk-crypto/src/error.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use ruma::{IdParseError, OwnedDeviceId, OwnedRoomId, OwnedUserId}; +use ruma::{signatures::CanonicalJsonError, IdParseError, OwnedDeviceId, OwnedRoomId, OwnedUserId}; use serde_json::Error as SerdeError; use thiserror::Error; @@ -181,7 +181,13 @@ pub enum SignatureError { /// The signed object couldn't be deserialized. #[error(transparent)] - JsonError(#[from] SerdeError), + JsonError(#[from] CanonicalJsonError), +} + +impl From for SignatureError { + fn from(e: SerdeError) -> Self { + CanonicalJsonError::SerDe(e).into() + } } #[derive(Error, Debug)] diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index d843bc1be..8b61ab9db 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -32,7 +32,7 @@ use ruma::{ AnyToDeviceEventContent, }, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, OwnedDeviceId, - OwnedDeviceKeyId, OwnedUserId, UserId, + OwnedDeviceKeyId, UserId, }; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::{json, Value}; @@ -47,7 +47,7 @@ use crate::{ identities::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities}, olm::{InboundGroupSession, Session, VerifyJson}, store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult}, - types::{DeviceKey, DeviceKeys, SignedKey}, + types::{DeviceKey, DeviceKeys, Signatures, SignedKey}, verification::VerificationMachine, OutgoingVerificationRequest, ReadOnlyAccount, Sas, ToDeviceRequest, VerificationRequest, }; @@ -433,7 +433,7 @@ impl ReadOnlyDevice { } /// Get a map containing all the device signatures. - pub fn signatures(&self) -> &BTreeMap> { + pub fn signatures(&self) -> &Signatures { &self.inner.signatures } diff --git a/crates/matrix-sdk-crypto/src/identities/user.rs b/crates/matrix-sdk-crypto/src/identities/user.rs index b64e79a67..1491e5cc1 100644 --- a/crates/matrix-sdk-crypto/src/identities/user.rs +++ b/crates/matrix-sdk-crypto/src/identities/user.rs @@ -28,7 +28,7 @@ use ruma::{ events::{ key::verification::VerificationMethod, room::message::KeyVerificationRequestEventContent, }, - DeviceKeyId, EventId, OwnedDeviceId, OwnedDeviceKeyId, OwnedUserId, RoomId, UserId, + DeviceKeyId, EventId, OwnedDeviceId, OwnedDeviceKeyId, RoomId, UserId, }; use serde::{Deserialize, Serialize}; use serde_json::to_value; @@ -40,7 +40,7 @@ use crate::{ error::SignatureError, olm::VerifyJson, store::{Changes, IdentityChanges}, - types::{CrossSigningKey, DeviceKeys, SigningKey}, + types::{CrossSigningKey, DeviceKeys, Signatures, SigningKey}, verification::VerificationMachine, CryptoStoreError, OutgoingVerificationRequest, ReadOnlyDevice, VerificationRequest, }; @@ -404,7 +404,7 @@ impl MasterPubkey { } /// Get the signatures map of this cross signing key. - pub fn signatures(&self) -> &BTreeMap> { + pub fn signatures(&self) -> &Signatures { &self.0.signatures } diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index 563209496..088c9b62f 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -129,7 +129,7 @@ impl OlmMachine { /// /// * `device_id` - The unique id of the device that owns this machine. pub async fn new(user_id: &UserId, device_id: &DeviceId) -> Self { - let store: Box = Box::new(MemoryStore::new()); + let store: Arc = Arc::new(MemoryStore::new()); OlmMachine::with_store(user_id, device_id, store) .await @@ -139,14 +139,13 @@ impl OlmMachine { fn new_helper( user_id: &UserId, device_id: &DeviceId, - store: Box, + store: Arc, account: ReadOnlyAccount, user_identity: PrivateCrossSigningIdentity, ) -> Self { let user_id: Arc = user_id.into(); let user_identity = Arc::new(Mutex::new(user_identity)); - let store: Arc = store.into(); let verification_machine = VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); let store = @@ -216,7 +215,7 @@ impl OlmMachine { pub async fn with_store( user_id: &UserId, device_id: &DeviceId, - store: Box, + store: Arc, ) -> StoreResult { let account = match store.load_account().await? { Some(a) => { @@ -636,7 +635,7 @@ impl OlmMachine { /// Encrypt a room message for the given room. /// /// Beware that a group session needs to be shared before this method can be - /// called using the [`share_group_session`] method. + /// called using the [`OlmMachine::share_group_session`] method. /// /// # Arguments /// diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index 115d97bb6..7997e55e2 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -35,7 +35,7 @@ use ruma::{ }, AnyToDeviceEvent, OlmV1Keys, }, - serde::{CanonicalJsonValue, Raw}, + serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, EventEncryptionAlgorithm, OwnedDeviceId, OwnedDeviceKeyId, OwnedUserId, RoomId, SecondsSinceUnixEpoch, UInt, UserId, }; @@ -49,8 +49,8 @@ use vodozemac::{ }; use super::{ - EncryptionSettings, InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity, - Session, + utility::SignJson, EncryptionSettings, InboundGroupSession, OutboundGroupSession, + PrivateCrossSigningIdentity, Session, }; use crate::{ error::{EventError, OlmResult, SessionCreationError}, @@ -739,7 +739,7 @@ impl ReadOnlyAccount { (*self.device_id).to_owned(), Self::ALGORITHMS.iter().map(|a| (**a).clone()).collect(), keys, - BTreeMap::new(), + Default::default(), ) } @@ -752,10 +752,15 @@ impl ReadOnlyAccount { // get signed. let json_device_keys = serde_json::to_value(&device_keys).expect("device key is always safe to serialize"); + let signature = self + .sign_json(json_device_keys) + .await + .expect("Newly created device keys can always be signed"); - device_keys.signatures.entry(self.user_id().to_owned()).or_default().insert( + device_keys.signatures.add_signature( + self.user_id().to_owned(), DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id), - self.sign_json(json_device_keys).await.to_base64(), + signature, ); device_keys @@ -773,11 +778,12 @@ impl ReadOnlyAccount { &self, cross_signing_key: &mut CrossSigningKey, ) -> Result<(), SignatureError> { - let signature = self.sign_json(serde_json::to_value(&cross_signing_key)?).await; + let signature = self.sign_json(serde_json::to_value(&cross_signing_key)?).await?; - cross_signing_key.signatures.entry(self.user_id().to_owned()).or_default().insert( + cross_signing_key.signatures.add_signature( + self.user_id().to_owned(), DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()), - signature.to_base64(), + signature, ); Ok(()) @@ -809,19 +815,8 @@ impl ReadOnlyAccount { /// /// * `json` - The value that should be converted into a canonical JSON /// string. - /// - /// # Panic - /// - /// Panics if the json value can't be serialized. - pub async fn sign_json(&self, mut json: Value) -> Ed25519Signature { - let object = json.as_object_mut().expect("Canonical json value isn't an object"); - object.remove("unsigned"); - object.remove("signatures"); - - let canonical_json: CanonicalJsonValue = - json.try_into().expect("Can't canonicalize the json value"); - - self.sign(&canonical_json.to_string()).await + pub async fn sign_json(&self, json: Value) -> Result { + self.inner.lock().await.sign_json(json) } /// Generate, sign and prepare one-time keys to be uploaded. @@ -885,18 +880,16 @@ impl ReadOnlyAccount { SignedKey::new(key.to_owned()) }; - let signature = - self.sign_json(serde_json::to_value(&key).expect("Can't serialize a signed key")).await; + let signature = self + .sign_json(serde_json::to_value(&key).expect("Can't serialize a signed key")) + .await + .expect("Newly created one-time keys can always be signed"); - let signatures = BTreeMap::from([( + key.signatures_mut().add_signature( self.user_id().to_owned(), - BTreeMap::from([( - DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &self.device_id), - signature, - )]), - )]); - - *key.signatures() = signatures; + DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, self.device_id()), + signature, + ); key } @@ -1012,8 +1005,7 @@ impl ReadOnlyAccount { message: &PreKeyMessage, ) -> Result { let their_identity_key = Curve25519PublicKey::from_base64(their_identity_key)?; - let result = - self.inner.lock().await.create_inbound_session(&their_identity_key, message)?; + let result = self.inner.lock().await.create_inbound_session(their_identity_key, message)?; let now = SecondsSinceUnixEpoch::now(); let session_id = result.session.session_id(); diff --git a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs index 9a538653b..56c0da79c 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs @@ -651,13 +651,16 @@ impl PrivateCrossSigningIdentity { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use matrix_sdk_test::async_test; - use ruma::{device_id, user_id, UserId}; + use ruma::{device_id, user_id, DeviceKeyAlgorithm, DeviceKeyId, UserId}; + use serde_json::json; use super::{PrivateCrossSigningIdentity, Signing}; use crate::{ identities::{ReadOnlyDevice, ReadOnlyUserIdentity}, - olm::ReadOnlyAccount, + olm::{utility::SignJson, ReadOnlyAccount}, }; fn user_id() -> &'static UserId { @@ -667,11 +670,25 @@ mod tests { #[test] fn signature_verification() { let signing = Signing::new(); + let user_id = user_id(); + let key_id = DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, "DEVICEID".into()); - let message = "Hello world"; + let json = json!({ + "hello": "world" + }); - let signature = signing.sign(message); - assert!(signing.verify(message, &signature).is_ok()); + let signature = + signing.sign_json(json).expect("We should be able to sign a simple json object"); + let signatures = + BTreeMap::from([(user_id, BTreeMap::from([(key_id.clone(), signature.to_base64())]))]); + + let mut json = json!({ + "hello": "world", + "signatures": signatures, + + }); + + assert!(signing.verify_json(user_id, &key_id, &mut json).is_ok()); } #[test] @@ -763,8 +780,11 @@ mod tests { let master = user_signing.sign_user(&bob_public).unwrap(); - let num_signatures: usize = master.signatures.iter().map(|(_, u)| u.len()).sum(); - assert_eq!(num_signatures, 1, "We're only uploading our own signature"); + assert_eq!( + master.signatures.signature_count(), + 1, + "We're only uploading our own signature" + ); bob_public.master_key = master.into(); diff --git a/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs b/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs index 1f7b137ba..6c445642c 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::BTreeMap, convert::TryInto}; +use std::collections::BTreeMap; -use ruma::{ - encryption::KeyUsage, serde::CanonicalJsonValue, DeviceKeyAlgorithm, DeviceKeyId, OwnedUserId, -}; +use ruma::{encryption::KeyUsage, DeviceKeyAlgorithm, DeviceKeyId, OwnedUserId}; use serde::{Deserialize, Serialize}; use serde_json::{Error as JsonError, Value}; use thiserror::Error; @@ -25,7 +23,8 @@ use vodozemac::{Ed25519PublicKey, Ed25519SecretKey, Ed25519Signature, KeyError}; use crate::{ error::SignatureError, identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, - types::{CrossSigningKey, CrossSigningKeySignatures, DeviceKeys}, + olm::utility::SignJson, + types::{CrossSigningKey, DeviceKeys, Signatures}, utilities::{encode, DecodeError}, ReadOnlyUserIdentity, }; @@ -64,6 +63,12 @@ impl PartialEq for Signing { } } +impl SignJson for Signing { + fn sign_json(&self, value: Value) -> Result { + self.inner.sign_json(value) + } +} + #[derive(PartialEq, Debug)] pub struct MasterSigning { pub inner: Signing, @@ -123,17 +128,14 @@ impl MasterSigning { let json_subkey = serde_json::to_value(&subkey).expect("Can't serialize cross signing key"); let signature = self.inner.sign_json(json_subkey).expect("Can't sign cross signing keys"); - subkey - .signatures - .entry(self.public_key.user_id().to_owned()) - .or_insert_with(BTreeMap::new) - .insert( - DeviceKeyId::from_parts( - DeviceKeyAlgorithm::Ed25519, - self.inner.public_key.to_base64().as_str().into(), - ), - signature.to_base64(), - ); + subkey.signatures.add_signature( + self.public_key.user_id().to_owned(), + DeviceKeyId::from_parts( + DeviceKeyAlgorithm::Ed25519, + self.inner.public_key.to_base64().as_str().into(), + ), + signature, + ); } } @@ -170,22 +172,20 @@ impl UserSigning { pub fn sign_user_helper( &self, user: &ReadOnlyUserIdentity, - ) -> Result { + ) -> Result { let user_master: &CrossSigningKey = user.master_key().as_ref(); let signature = self.inner.sign_json(serde_json::to_value(user_master)?)?; - let mut signatures = BTreeMap::new(); + let mut signatures = Signatures::new(); - signatures - .entry(self.public_key.user_id().to_owned()) - .or_insert_with(BTreeMap::new) - .insert( - DeviceKeyId::from_parts( - DeviceKeyAlgorithm::Ed25519, - self.inner.public_key.to_base64().as_str().into(), - ), - signature.to_base64(), - ); + signatures.add_signature( + self.public_key.user_id().to_owned(), + DeviceKeyId::from_parts( + DeviceKeyAlgorithm::Ed25519, + self.inner.public_key.to_base64().as_str().into(), + ), + signature, + ); Ok(signatures) } @@ -215,19 +215,17 @@ impl SelfSigning { Ok(Self { inner, public_key }) } - pub fn sign_device_helper(&self, value: Value) -> Result { - self.inner.sign_json(value) - } - pub fn sign_device(&self, device_keys: &mut DeviceKeys) -> Result<(), SignatureError> { - let signature = self.sign_device_helper(serde_json::to_value(&device_keys)?)?; + let serialized = serde_json::to_value(&device_keys)?; + let signature = self.inner.sign_json(serialized)?; - device_keys.signatures.entry(self.public_key.user_id().to_owned()).or_default().insert( + device_keys.signatures.add_signature( + self.public_key.user_id().to_owned(), DeviceKeyId::from_parts( DeviceKeyAlgorithm::Ed25519, self.inner.public_key.to_base64().as_str().into(), ), - signature.to_base64(), + signature, ); Ok(()) @@ -308,30 +306,21 @@ impl Signing { self.inner.public_key().into(), )]); - CrossSigningKey::new(user_id, vec![usage], keys, BTreeMap::new()) - } - - #[cfg(test)] - pub fn verify( - &self, - message: &str, - signature: &Ed25519Signature, - ) -> Result<(), SignatureError> { - Ok(self.public_key.verify(message.as_bytes(), signature)?) - } - - pub fn sign_json(&self, mut json: Value) -> Result { - let json_object = json.as_object_mut().ok_or(SignatureError::NotAnObject)?; - let _ = json_object.remove("signatures"); - let _ = json_object.remove("unsigned"); - - let canonical_json: CanonicalJsonValue = - json.try_into().expect("Can't canonicalize the json value"); - - Ok(self.sign(&canonical_json.to_string())) + CrossSigningKey::new(user_id, vec![usage], keys, Default::default()) } pub fn sign(&self, message: &str) -> Ed25519Signature { self.inner.sign(message.as_bytes()) } + + #[cfg(test)] + pub fn verify_json( + &self, + user_id: &ruma::UserId, + key_id: &DeviceKeyId, + message: &mut Value, + ) -> Result<(), SignatureError> { + use crate::olm::VerifyJson; + self.public_key.verify_json(user_id, key_id, message) + } } diff --git a/crates/matrix-sdk-crypto/src/olm/utility.rs b/crates/matrix-sdk-crypto/src/olm/utility.rs index a95ec607f..60b3b54fc 100644 --- a/crates/matrix-sdk-crypto/src/olm/utility.rs +++ b/crates/matrix-sdk-crypto/src/olm/utility.rs @@ -16,9 +16,39 @@ use std::convert::TryInto; use ruma::{serde::CanonicalJsonValue, DeviceKeyAlgorithm, DeviceKeyId, UserId}; use serde_json::Value; +use vodozemac::{olm::Account, Ed25519SecretKey, Ed25519Signature}; use crate::error::SignatureError; +pub trait SignJson { + fn sign_json(&self, value: Value) -> Result; + + fn to_signable_json(mut value: Value) -> Result { + let json_object = value.as_object_mut().ok_or(SignatureError::NotAnObject)?; + let _ = json_object.remove("signatures"); + let _ = json_object.remove("unsigned"); + + let canonical_json: CanonicalJsonValue = value.try_into()?; + Ok(canonical_json.to_string()) + } +} + +impl SignJson for Account { + fn sign_json(&self, value: Value) -> Result { + let serialized = Self::to_signable_json(value)?; + + Ok(self.sign(serialized.as_ref())) + } +} + +impl SignJson for Ed25519SecretKey { + fn sign_json(&self, value: Value) -> Result { + let serialized = Self::to_signable_json(value)?; + + Ok(self.sign(serialized.as_ref())) + } +} + pub trait VerifyJson { /// Verify a signed JSON object. /// diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index c82fe9345..f168ee264 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -22,6 +22,7 @@ //! `wasm-unknown-unknown` an indexeddb store would be needed //! //! ``` +//! # use std::sync::Arc; //! # use matrix_sdk_crypto::{ //! # OlmMachine, //! # store::MemoryStore, @@ -29,7 +30,7 @@ //! # use ruma::{device_id, user_id}; //! # let user_id = user_id!("@example:localhost"); //! # let device_id = device_id!("TEST"); -//! let store = Box::new(MemoryStore::new()); +//! let store = Arc::new(MemoryStore::new()); //! //! let machine = OlmMachine::with_store(user_id, device_id, store); //! ``` diff --git a/crates/matrix-sdk-crypto/src/types/cross_signing_key.rs b/crates/matrix-sdk-crypto/src/types/cross_signing_key.rs index 753b09ff2..3ae3babff 100644 --- a/crates/matrix-sdk-crypto/src/types/cross_signing_key.rs +++ b/crates/matrix-sdk-crypto/src/types/cross_signing_key.rs @@ -26,8 +26,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{value::to_raw_value, Value}; use vodozemac::Ed25519PublicKey; -/// Signatures for a `CrossSigningKey` object. -pub type CrossSigningKeySignatures = BTreeMap>; +use super::Signatures; /// A cross signing key. #[derive(Clone, Debug, Deserialize, Serialize)] @@ -47,8 +46,7 @@ pub struct CrossSigningKey { /// Signatures of the key. /// /// Only optional for master key. - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - pub signatures: CrossSigningKeySignatures, + pub signatures: Signatures, #[serde(flatten)] other: BTreeMap, @@ -61,7 +59,7 @@ impl CrossSigningKey { user_id: OwnedUserId, usage: Vec, keys: BTreeMap, - signatures: CrossSigningKeySignatures, + signatures: Signatures, ) -> Self { Self { user_id, usage, keys, signatures, other: BTreeMap::new() } } @@ -77,7 +75,7 @@ impl CrossSigningKey { /// Currently cross signing keys support an ed25519 keypair. The keys transport /// format is a base64 encoded string, any unknown key type will be left as such /// a string. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum SigningKey { /// The ed25519 cross-signing key. Ed25519(Ed25519PublicKey), @@ -106,8 +104,8 @@ struct CrossSigningKeyHelper { pub user_id: OwnedUserId, pub usage: Vec, pub keys: BTreeMap, - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - pub signatures: CrossSigningKeySignatures, + #[serde(default, skip_serializing_if = "Signatures::is_empty")] + pub signatures: Signatures, #[serde(flatten)] other: BTreeMap, } diff --git a/crates/matrix-sdk-crypto/src/types/device_keys.rs b/crates/matrix-sdk-crypto/src/types/device_keys.rs index aa558f153..50e23147e 100644 --- a/crates/matrix-sdk-crypto/src/types/device_keys.rs +++ b/crates/matrix-sdk-crypto/src/types/device_keys.rs @@ -28,6 +28,8 @@ use serde::{Deserialize, Serialize}; use serde_json::{value::to_raw_value, Value}; use vodozemac::{Curve25519PublicKey, Ed25519PublicKey}; +use super::Signatures; + /// Identity keys for a device. #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(try_from = "DeviceKeyHelper", into = "DeviceKeyHelper")] @@ -49,7 +51,7 @@ pub struct DeviceKeys { pub keys: BTreeMap, /// Signatures for the device key object. - pub signatures: BTreeMap>, + pub signatures: Signatures, /// Additional data added to the device key information by intermediate /// servers, and not covered by the signatures. @@ -68,7 +70,7 @@ impl DeviceKeys { device_id: OwnedDeviceId, algorithms: Vec, keys: BTreeMap, - signatures: BTreeMap>, + signatures: Signatures, ) -> Self { Self { user_id, @@ -115,7 +117,7 @@ impl UnsignedDeviceInfo { /// Currently devices have a curve25519 and ed25519 keypair. The keys transport /// format is a base64 encoded string, any unknown key type will be left as such /// a string. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum DeviceKey { /// The curve25519 device key. Curve25519(Curve25519PublicKey), @@ -154,7 +156,7 @@ struct DeviceKeyHelper { pub device_id: OwnedDeviceId, pub algorithms: Vec, pub keys: BTreeMap, - pub signatures: BTreeMap>, + pub signatures: Signatures, #[serde(default, skip_serializing_if = "UnsignedDeviceInfo::is_empty")] pub unsigned: UnsignedDeviceInfo, #[serde(flatten)] diff --git a/crates/matrix-sdk-crypto/src/types/mod.rs b/crates/matrix-sdk-crypto/src/types/mod.rs index c74ed9666..e891dbf1a 100644 --- a/crates/matrix-sdk-crypto/src/types/mod.rs +++ b/crates/matrix-sdk-crypto/src/types/mod.rs @@ -27,6 +27,159 @@ mod cross_signing_key; mod device_keys; mod one_time_keys; +use std::collections::BTreeMap; + pub use cross_signing_key::*; pub use device_keys::*; pub use one_time_keys::*; +use ruma::{DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceKeyId, OwnedUserId, UserId}; +use serde::{Deserialize, Serialize, Serializer}; +use vodozemac::Ed25519Signature; + +/// An enum over all the signature types. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Signature { + /// A Ed25519 digital signature. + Ed25519(Ed25519Signature), + /// An unknown digital signature as a base64 encoded string. + Other(String), +} + +impl Signature { + /// Get the Ed25519 signature, if this is one. + pub fn ed25519(&self) -> Option { + if let Self::Ed25519(signature) = &self { + Some(*signature) + } else { + None + } + } + + /// Convert the signature to a base64 encoded string. + pub fn to_base64(&self) -> String { + match self { + Signature::Ed25519(s) => s.to_base64(), + Signature::Other(s) => s.to_owned(), + } + } +} + +impl From for Signature { + fn from(signature: Ed25519Signature) -> Self { + Self::Ed25519(signature) + } +} + +/// Signatures for a signed object. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Signatures(BTreeMap>); + +impl Signatures { + /// Create a new, empty, signatures collection. + pub fn new() -> Self { + Signatures(Default::default()) + } + + /// Add the given signature from the given signer and the given key_id to + /// the collection. + pub fn add_signature( + &mut self, + signer: OwnedUserId, + key_id: OwnedDeviceKeyId, + signature: Ed25519Signature, + ) -> Option { + self.0.entry(signer).or_insert_with(Default::default).insert(key_id, signature.into()) + } + + /// Try to find an Ed25519 signature from the given signer with the given + /// key id. + pub fn get_signature(&self, signer: &UserId, key_id: &DeviceKeyId) -> Option { + self.get(signer)?.get(key_id)?.ed25519() + } + + /// Get the map of signatures that belong to the given user. + pub fn get(&self, signer: &UserId) -> Option<&BTreeMap> { + self.0.get(signer) + } + + /// Remove all the signatures we currently hold. + pub fn clear(&mut self) { + self.0.clear() + } + + /// Do we hold any signatures or is our collection completely empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// How many signatures do we currently hold. + pub fn signature_count(&self) -> usize { + self.0.iter().map(|(_, u)| u.len()).sum() + } +} + +impl Default for Signatures { + fn default() -> Self { + Self::new() + } +} + +impl IntoIterator for Signatures { + type Item = (OwnedUserId, BTreeMap); + + type IntoIter = + std::collections::btree_map::IntoIter>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'de> Deserialize<'de> for Signatures { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let map: BTreeMap> = + serde::Deserialize::deserialize(deserializer)?; + + let map = map + .into_iter() + .map(|(user, signatures)| { + let signatures = signatures + .into_iter() + .map(|(key_id, s)| { + let algorithm = key_id.algorithm(); + let signature = match algorithm { + DeviceKeyAlgorithm::Ed25519 => Ed25519Signature::from_base64(&s) + .map_err(serde::de::Error::custom)? + .into(), + _ => Signature::Other(s), + }; + + Ok((key_id, signature)) + }) + .collect::, _>>()?; + + Ok((user, signatures)) + }) + .collect::>()?; + + Ok(Signatures(map)) + } +} + +impl Serialize for Signatures { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let signatures: BTreeMap<&OwnedUserId, BTreeMap<&OwnedDeviceKeyId, String>> = self + .0 + .iter() + .map(|(u, m)| (u, m.iter().map(|(d, s)| (d, s.to_base64())).collect())) + .collect(); + + serde::Serialize::serialize(&signatures, serializer) + } +} diff --git a/crates/matrix-sdk-crypto/src/types/one_time_keys.rs b/crates/matrix-sdk-crypto/src/types/one_time_keys.rs index f3c609535..6e2452fa1 100644 --- a/crates/matrix-sdk-crypto/src/types/one_time_keys.rs +++ b/crates/matrix-sdk-crypto/src/types/one_time_keys.rs @@ -20,24 +20,22 @@ use std::collections::BTreeMap; -use ruma::{serde::Raw, OwnedDeviceKeyId, OwnedUserId}; +use ruma::serde::Raw; use serde::{Deserialize, Serialize, Serializer}; use serde_json::{value::to_raw_value, Value}; -use vodozemac::{Curve25519PublicKey, Ed25519Signature}; +use vodozemac::Curve25519PublicKey; -/// Signatures for a `SignedKey` object. -pub type SignedKeySignatures = BTreeMap>; +use super::Signatures; /// A key for the SignedCurve25519 algorithm #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct SignedKey { - // /// The Curve25519 key that can be used to establish Olm sessions. + /// The Curve25519 key that can be used to establish Olm sessions. #[serde(deserialize_with = "deserialize_curve_key", serialize_with = "serialize_curve_key")] key: Curve25519PublicKey, /// Signatures for the key object. - #[serde(deserialize_with = "deserialize_signatures", serialize_with = "serialize_signatures")] - signatures: SignedKeySignatures, + signatures: Signatures, /// Is the key considered to be a fallback key. #[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "double_option")] @@ -67,42 +65,6 @@ where s.serialize_str(&key) } -fn deserialize_signatures<'de, D>(de: D) -> Result -where - D: serde::Deserializer<'de>, -{ - let map: BTreeMap> = - Deserialize::deserialize(de)?; - - map.into_iter() - .map(|(u, m)| { - Ok(( - u, - m.into_iter() - .map(|(d, s)| { - Ok(( - d, - Ed25519Signature::from_base64(&s).map_err(serde::de::Error::custom)?, - )) - }) - .collect::, _>>()?, - )) - }) - .collect::>() -} - -fn serialize_signatures(signatures: &SignedKeySignatures, s: S) -> Result -where - S: Serializer, -{ - let signatures: BTreeMap<&OwnedUserId, BTreeMap<&OwnedDeviceKeyId, String>> = signatures - .iter() - .map(|(u, m)| (u, m.iter().map(|(d, s)| (d, s.to_base64())).collect())) - .collect(); - - signatures.serialize(s) -} - fn double_option<'de, T, D>(de: D) -> Result>, D::Error> where T: Deserialize<'de>, @@ -114,7 +76,7 @@ where impl SignedKey { /// Creates a new `SignedKey` with the given key and signatures. pub fn new(key: Curve25519PublicKey) -> Self { - Self { key, signatures: BTreeMap::new(), fallback: None, other: BTreeMap::new() } + Self { key, signatures: Signatures::new(), fallback: None, other: BTreeMap::new() } } /// Creates a new `SignedKey`, that represents a fallback key, with the @@ -122,7 +84,7 @@ impl SignedKey { pub fn new_fallback(key: Curve25519PublicKey) -> Self { Self { key, - signatures: BTreeMap::new(), + signatures: Signatures::new(), fallback: Some(Some(true)), other: BTreeMap::new(), } @@ -134,7 +96,12 @@ impl SignedKey { } /// Signatures for the key object. - pub fn signatures(&mut self) -> &mut SignedKeySignatures { + pub fn signatures(&self) -> &Signatures { + &self.signatures + } + + /// Signatures for the key object as a mutable borrow. + pub fn signatures_mut(&mut self) -> &mut Signatures { &mut self.signatures } @@ -168,18 +135,24 @@ pub enum OneTimeKey { #[cfg(test)] mod tests { use matches::assert_matches; + use ruma::{device_id, user_id, DeviceKeyAlgorithm, DeviceKeyId}; use serde_json::json; - use vodozemac::Curve25519PublicKey; + use vodozemac::{Curve25519PublicKey, Ed25519Signature}; use super::OneTimeKey; + use crate::types::Signature; #[test] fn serialization() { + let user_id = user_id!("@user:example.com"); + let device_id = device_id!("EGURVBUNJP"); + let json = json!({ "key":"XjhWTCjW7l59pbfx9tlCBQolfnIQWARoKOzjTOPSlWM", "signatures": { - "@user:example.com": { - "ed25519:EGURVBUNJP": "mia28GKixFzOWKJ0h7Bdrdy2fjxiHCsst1qpe467FbW85H61UlshtKBoAXfTLlVfi0FX+/noJ8B3noQPnY+9Cg" + user_id: { + "ed25519:EGURVBUNJP": "mia28GKixFzOWKJ0h7Bdrdy2fjxiHCsst1qpe467FbW85H61UlshtKBoAXfTLlVfi0FX+/noJ8B3noQPnY+9Cg", + "other:EGURVBUNJP": "UnknownSignature" } }, "extra_key": "extra_value" @@ -189,10 +162,27 @@ mod tests { Curve25519PublicKey::from_base64("XjhWTCjW7l59pbfx9tlCBQolfnIQWARoKOzjTOPSlWM") .expect("Can't construct curve key from base64"); + let signature = Ed25519Signature::from_base64( + "mia28GKixFzOWKJ0h7Bdrdy2fjxiHCsst1qpe467FbW85H61UlshtKBoAXfTLlVfi0FX+/noJ8B3noQPnY+9Cg" + ).expect("The signature can always be decoded"); + + let custom_signature = Signature::Other("UnknownSignature".to_owned()); + let key: OneTimeKey = serde_json::from_value(json.clone()).expect("Can't deserialize a valid one-time key"); - assert_matches!(key, OneTimeKey::SignedKey(ref k) if k.key == curve_key); + assert_matches!( + key, + OneTimeKey::SignedKey(ref k) if k.key == curve_key + && k + .signatures() + .get_signature(user_id, &DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, device_id)) + == Some(signature) + && k + .signatures() + .get(user_id).unwrap().get(&DeviceKeyId::from_parts("other".into(), device_id)) + == Some(&custom_signature) + ); let serialized = serde_json::to_value(key).expect("Can't reserialize a signed key"); diff --git a/crates/matrix-sdk-crypto/src/verification/mod.rs b/crates/matrix-sdk-crypto/src/verification/mod.rs index 496ca7188..672129cf2 100644 --- a/crates/matrix-sdk-crypto/src/verification/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/mod.rs @@ -20,10 +20,7 @@ mod qrcode; mod requests; mod sas; -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; +use std::{collections::HashMap, sync::Arc}; use event_enums::OutgoingContent; pub use machine::VerificationMachine; @@ -47,8 +44,8 @@ use ruma::{ }, AnyMessageLikeEventContent, AnyToDeviceEventContent, }, - DeviceId, EventId, OwnedDeviceId, OwnedDeviceKeyId, OwnedEventId, OwnedRoomId, - OwnedTransactionId, OwnedUserId, RoomId, UserId, + DeviceId, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedTransactionId, RoomId, + UserId, }; pub use sas::{AcceptSettings, Sas}; use tracing::{error, info, trace, warn}; @@ -58,6 +55,7 @@ use crate::{ gossiping::{GossipMachine, GossipRequest}, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount, Session}, store::{Changes, CryptoStore}, + types::Signatures, CryptoStoreError, LocalTrust, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, }; @@ -140,10 +138,7 @@ impl VerificationStore { } /// Get the signatures that have signed our own device. - pub async fn device_signatures( - &self, - ) -> Result>>, CryptoStoreError> - { + pub async fn device_signatures(&self) -> Result, CryptoStoreError> { Ok(self .inner .get_device(self.account.user_id(), self.account.device_id()) diff --git a/crates/matrix-sdk-ffi/src/client.rs b/crates/matrix-sdk-ffi/src/client.rs index ef7cf5e0d..b8b7e056a 100644 --- a/crates/matrix-sdk-ffi/src/client.rs +++ b/crates/matrix-sdk-ffi/src/client.rs @@ -106,7 +106,7 @@ impl Client { pub fn restore_token(&self) -> anyhow::Result { RUNTIME.block_on(async move { - let session = self.client.session().await.expect("Missing session"); + let session = self.client.session().expect("Missing session").clone(); let homeurl = self.client.homeserver().await.into(); Ok(serde_json::to_string(&RestoreToken { session, @@ -121,11 +121,8 @@ impl Client { } pub fn user_id(&self) -> anyhow::Result { - let l = self.client.clone(); - RUNTIME.block_on(async move { - let user_id = l.user_id().await.expect("No User ID found"); - Ok(user_id.to_string()) - }) + let user_id = self.client.user_id().expect("No User ID found"); + Ok(user_id.to_string()) } pub fn display_name(&self) -> anyhow::Result { @@ -145,11 +142,8 @@ impl Client { } pub fn device_id(&self) -> anyhow::Result { - let l = self.client.clone(); - RUNTIME.block_on(async move { - let device_id = l.device_id().await.expect("No Device ID found"); - Ok(device_id.to_string()) - }) + let device_id = self.client.device_id().expect("No Device ID found"); + Ok(device_id.to_string()) } pub fn get_media_content(&self, media_source: Arc) -> anyhow::Result> { diff --git a/crates/matrix-sdk-indexeddb/Cargo.toml b/crates/matrix-sdk-indexeddb/Cargo.toml index 15e19d61d..6ec6f9b99 100644 --- a/crates/matrix-sdk-indexeddb/Cargo.toml +++ b/crates/matrix-sdk-indexeddb/Cargo.toml @@ -8,13 +8,16 @@ edition = "2021" rust-version = "1.60" readme = "README.md" +[package.metadata.docs.rs] +all-features = true +default-target = "wasm32-unknown-unknown" +rustdoc-args = ["--cfg", "docsrs"] + [features] default = ["e2e-encryption"] e2e-encryption = ["matrix-sdk-base/e2e-encryption", "matrix-sdk-crypto"] experimental-timeline = ["matrix-sdk-base/experimental-timeline"] -[package.metadata.docs.rs] -default-target = "wasm32-unknown-unknown" [dependencies] anyhow = "1.0.57" diff --git a/crates/matrix-sdk-qrcode/Cargo.toml b/crates/matrix-sdk-qrcode/Cargo.toml index 417afaade..5cb5f1de7 100644 --- a/crates/matrix-sdk-qrcode/Cargo.toml +++ b/crates/matrix-sdk-qrcode/Cargo.toml @@ -29,4 +29,5 @@ ruma-common = "0.9.0" thiserror = "1.0.30" [dependencies.vodozemac] -version = "0.2.0" +git = "https://github.com/matrix-org/vodozemac/" +rev = "d0e744287a14319c2a9148fef3747548c740fc36" diff --git a/crates/matrix-sdk-qrcode/src/types.rs b/crates/matrix-sdk-qrcode/src/types.rs index a94e7f5d8..26dc4ef25 100644 --- a/crates/matrix-sdk-qrcode/src/types.rs +++ b/crates/matrix-sdk-qrcode/src/types.rs @@ -32,7 +32,7 @@ use crate::{ }; /// An enum representing the different modes a QR verification can be in. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum QrVerificationData { /// The QR verification is verifying another user Verification(VerificationData), @@ -375,7 +375,7 @@ impl QrVerificationData { /// /// This mode is used for verification between two users using their master /// cross signing keys. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct VerificationData { event_id: OwnedEventId, first_master_key: Ed25519PublicKey, @@ -474,7 +474,7 @@ impl From for QrVerificationData { /// This mode is used for verification between two devices of the same user /// where this device, that is creating this QR code, is trusting or owning /// the cross signing master key. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct SelfVerificationData { transaction_id: String, master_key: Ed25519PublicKey, @@ -577,7 +577,7 @@ impl From for QrVerificationData { /// This mode is used for verification between two devices of the same user /// where this device, that is creating this QR code, is not trusting the /// cross signing master key. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct SelfVerificationNoMasterKey { transaction_id: String, device_key: Ed25519PublicKey, diff --git a/crates/matrix-sdk-sled/Cargo.toml b/crates/matrix-sdk-sled/Cargo.toml index fec6b2c3e..787935531 100644 --- a/crates/matrix-sdk-sled/Cargo.toml +++ b/crates/matrix-sdk-sled/Cargo.toml @@ -9,6 +9,10 @@ license = "Apache-2.0" rust-version = "1.60" readme = "README.md" +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + [features] default = ["state-store"] diff --git a/crates/matrix-sdk-store-encryption/Cargo.toml b/crates/matrix-sdk-store-encryption/Cargo.toml index 370c88eff..935999454 100644 --- a/crates/matrix-sdk-store-encryption/Cargo.toml +++ b/crates/matrix-sdk-store-encryption/Cargo.toml @@ -7,6 +7,9 @@ repository = "https://github.com/matrix-org/matrix-rust-sdk" license = "Apache-2.0" rust-version = "1.60" +[package.metadata.docs.rs] +rustdoc-args = ["--cfg", "docsrs"] + [features] js = ["getrandom/js"] diff --git a/crates/matrix-sdk-test/src/test_json/mod.rs b/crates/matrix-sdk-test/src/test_json/mod.rs index 78f440b2e..2f6c9cc5c 100644 --- a/crates/matrix-sdk-test/src/test_json/mod.rs +++ b/crates/matrix-sdk-test/src/test_json/mod.rs @@ -46,6 +46,17 @@ pub static DEVICES: Lazy = Lazy::new(|| { }) }); +pub static GET_ALIAS: Lazy = Lazy::new(|| { + json!({ + "room_id": "!lUbmUPdxdXxEQurqOs:example.net", + "servers": [ + "example.org", + "example.net", + "matrix.org", + ] + }) +}); + pub static WELL_KNOWN: Lazy = Lazy::new(|| { json!({ "m.homeserver": { diff --git a/crates/matrix-sdk/examples/autojoin.rs b/crates/matrix-sdk/examples/autojoin.rs index 158e82aeb..e641f80b3 100644 --- a/crates/matrix-sdk/examples/autojoin.rs +++ b/crates/matrix-sdk/examples/autojoin.rs @@ -10,7 +10,7 @@ async fn on_stripped_state_member( client: Client, room: Room, ) { - if room_member.state_key != client.user_id().await.unwrap() { + if room_member.state_key != client.user_id().unwrap() { return; } diff --git a/crates/matrix-sdk/src/account.rs b/crates/matrix-sdk/src/account.rs index b195a1f0c..97a805381 100644 --- a/crates/matrix-sdk/src/account.rs +++ b/crates/matrix-sdk/src/account.rs @@ -70,8 +70,8 @@ impl Account { /// # anyhow::Ok(()) }); /// ``` pub async fn get_display_name(&self) -> Result> { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; - let request = get_display_name::v3::Request::new(&user_id); + let user_id = self.client.user_id().ok_or(Error::AuthenticationRequired)?; + let request = get_display_name::v3::Request::new(user_id); let response = self.client.send(request, None).await?; Ok(response.displayname) } @@ -93,8 +93,8 @@ impl Account { /// # anyhow::Ok(()) }); /// ``` pub async fn set_display_name(&self, name: Option<&str>) -> Result<()> { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; - let request = set_display_name::v3::Request::new(&user_id, name); + let user_id = self.client.user_id().ok_or(Error::AuthenticationRequired)?; + let request = set_display_name::v3::Request::new(user_id, name); self.client.send(request, None).await?; Ok(()) } @@ -118,8 +118,8 @@ impl Account { /// # anyhow::Ok(()) }); /// ``` pub async fn get_avatar_url(&self) -> Result> { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; - let request = get_avatar_url::v3::Request::new(&user_id); + let user_id = self.client.user_id().ok_or(Error::AuthenticationRequired)?; + let request = get_avatar_url::v3::Request::new(user_id); let config = Some(RequestConfig::new().force_auth()); @@ -131,8 +131,8 @@ impl Account { /// /// The avatar is unset if `url` is `None`. pub async fn set_avatar_url(&self, url: Option<&MxcUri>) -> Result<()> { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; - let request = set_avatar_url::v3::Request::new(&user_id, url); + let user_id = self.client.user_id().ok_or(Error::AuthenticationRequired)?; + let request = set_avatar_url::v3::Request::new(user_id, url); self.client.send(request, None).await?; Ok(()) } @@ -233,8 +233,8 @@ impl Account { /// # anyhow::Ok(()) }); /// ``` pub async fn get_profile(&self) -> Result { - let user_id = self.client.user_id().await.ok_or(Error::AuthenticationRequired)?; - let request = get_profile::v3::Request::new(&user_id); + let user_id = self.client.user_id().ok_or(Error::AuthenticationRequired)?; + let request = get_profile::v3::Request::new(user_id); Ok(self.client.send(request, None).await?) } diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index 479c344ac..f97c2ea21 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -295,12 +295,7 @@ impl ClientBuilder { let base_client = BaseClient::with_store_config(self.store_config); let mk_http_client = |homeserver| { - HttpClient::new( - inner_http_client.clone(), - homeserver, - base_client.session().clone(), - self.request_config, - ) + HttpClient::new(inner_http_client.clone(), homeserver, self.request_config) }; let homeserver = match homeserver_cfg { diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 1691ae2b4..81c8fda52 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -43,6 +43,7 @@ use ruma::{ api::{ client::{ account::{register, whoami}, + alias::get_alias, device::{delete_devices, get_devices}, directory::{get_public_rooms, get_public_rooms_filtered}, discovery::{ @@ -64,8 +65,8 @@ use ruma::{ assign, events::room::MediaSource, presence::PresenceState, - MxcUri, OwnedDeviceId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, RoomOrAliasId, - ServerName, UInt, + DeviceId, MxcUri, OwnedDeviceId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, + RoomOrAliasId, ServerName, UInt, UserId, }; use serde::de::DeserializeOwned; #[cfg(not(target_arch = "wasm32"))] @@ -195,8 +196,8 @@ impl Client { } #[cfg(feature = "e2e-encryption")] - pub(crate) async fn olm_machine(&self) -> Option { - self.base_client().olm_machine().await + pub(crate) fn olm_machine(&self) -> Option<&matrix_sdk_base::crypto::OlmMachine> { + self.base_client().olm_machine() } #[cfg(feature = "e2e-encryption")] @@ -266,8 +267,8 @@ impl Client { } /// Is the client logged in. - pub async fn logged_in(&self) -> bool { - self.inner.base_client.logged_in().await + pub fn logged_in(&self) -> bool { + self.inner.base_client.logged_in() } /// The Homeserver of the client. @@ -276,15 +277,13 @@ impl Client { } /// Get the user id of the current owner of the client. - pub async fn user_id(&self) -> Option { - let session = self.inner.base_client.session().read().await; - session.as_ref().cloned().map(|s| s.user_id) + pub fn user_id(&self) -> Option<&UserId> { + self.inner.base_client.session().map(|s| s.user_id.as_ref()) } /// Get the device id that identifies the current session. - pub async fn device_id(&self) -> Option { - let session = self.inner.base_client.session().read().await; - session.as_ref().map(|s| s.device_id.clone()) + pub fn device_id(&self) -> Option<&DeviceId> { + self.inner.base_client.session().map(|s| s.device_id.as_ref()) } /// Get the whole session info of this client. @@ -293,8 +292,8 @@ impl Client { /// /// Can be used with [`Client::restore_login`] to restore a previously /// logged in session. - pub async fn session(&self) -> Option { - self.inner.base_client.session().read().await.clone() + pub fn session(&self) -> Option<&Session> { + self.inner.base_client.session() } /// Get a reference to the store. @@ -608,6 +607,20 @@ impl Client { self.store().get_room(room_id).and_then(|room| room::Left::new(self.clone(), room)) } + /// Resolve a room alias to a room id and a list of servers which know + /// about it. + /// + /// # Arguments + /// + /// `room_alias` - The room alias to be resolved. + pub async fn resolve_room_alias( + &self, + room_alias: &RoomAliasId, + ) -> HttpResult { + let request = get_alias::v3::Request::new(room_alias); + self.send(request, None).await + } + /// Gets the homeserver’s supported login types. /// /// This should be the first step when trying to login so you can call the @@ -1099,6 +1112,7 @@ impl Client { /// /// [`login`]: #method.login pub async fn restore_login(&self, session: Session) -> Result<()> { + self.inner.http_client.set_session(session.clone()); Ok(self.inner.base_client.restore_login(session).await?) } @@ -1210,8 +1224,8 @@ impl Client { if let Some(filter) = self.inner.base_client.get_filter(filter_name).await? { Ok(filter) } else { - let user_id = self.user_id().await.ok_or(Error::AuthenticationRequired)?; - let request = FilterUploadRequest::new(&user_id, definition); + let user_id = self.user_id().ok_or(Error::AuthenticationRequired)?; + let request = FilterUploadRequest::new(user_id, definition); let response = self.send(request, None).await?; self.inner.base_client.receive_filter_upload(filter_name, &response).await?; @@ -2407,7 +2421,7 @@ pub(crate) mod tests { client.login("example", "wordpass", None, None).await.unwrap(); - let logged_in = client.logged_in().await; + let logged_in = client.logged_in(); assert!(logged_in, "Client should be logged in"); assert_eq!(client.homeserver().await, homeserver); @@ -2424,7 +2438,7 @@ pub(crate) mod tests { client.login("example", "wordpass", None, None).await.unwrap(); - let logged_in = client.logged_in().await; + let logged_in = client.logged_in(); assert!(logged_in, "Client should be logged in"); assert_eq!(client.homeserver().await.as_str(), "https://example.org/"); @@ -2441,7 +2455,7 @@ pub(crate) mod tests { client.login("example", "wordpass", None, None).await.unwrap(); - let logged_in = client.logged_in().await; + let logged_in = client.logged_in(); assert!(logged_in, "Client should be logged in"); assert_eq!(client.homeserver().await, Url::parse(&mockito::server_url()).unwrap()); @@ -2485,7 +2499,7 @@ pub(crate) mod tests { .await .unwrap(); - let logged_in = client.logged_in().await; + let logged_in = client.logged_in(); assert!(logged_in, "Client should be logged in"); } @@ -2517,7 +2531,7 @@ pub(crate) mod tests { client.login_with_token("averysmalltoken", None, None).await.unwrap(); - let logged_in = client.logged_in().await; + let logged_in = client.logged_in(); assert!(logged_in, "Client should be logged in"); } @@ -2533,6 +2547,19 @@ pub(crate) mod tests { assert!(client.devices().await.is_ok()); } + #[async_test] + async fn resolve_room_alias() { + let client = no_retry_test_client().await; + + let _m = mock("GET", "/_matrix/client/r0/directory/room/%23alias%3Aexample%2Eorg") + .with_status(200) + .with_body(test_json::GET_ALIAS.to_string()) + .create(); + + let alias = ruma::room_alias_id!("#alias:example.org"); + assert!(client.resolve_room_alias(alias).await.is_ok()); + } + #[async_test] async fn test_join_leave_room() { let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost"); @@ -2543,7 +2570,7 @@ pub(crate) mod tests { .create(); let client = logged_in_client().await; - let session = client.session().await.unwrap(); + let session = client.session().unwrap().clone(); let room = client.get_joined_room(room_id); assert!(room.is_none()); diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index 16cbfddd6..109d8d5b1 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -213,9 +213,9 @@ impl Client { T: GlobalAccountDataEventContent, { let own_user = - self.user_id().await.ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?; + self.user_id().ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?; - let request = set_global_account_data::v3::Request::new(&content, &own_user)?; + let request = set_global_account_data::v3::Request::new(&content, own_user)?; Ok(self.send(request, None).await?) } @@ -473,7 +473,7 @@ impl Encryption { /// Get the public ed25519 key of our own device. This is usually what is /// called the fingerprint of the device. pub async fn ed25519_key(&self) -> Option { - self.client.olm_machine().await.map(|o| o.identity_keys().ed25519.to_base64()) + self.client.olm_machine().map(|o| o.identity_keys().ed25519.to_base64()) } /// Get the status of the private cross signing keys. @@ -481,7 +481,7 @@ impl Encryption { /// This can be used to check which private cross signing keys we have /// stored locally. pub async fn cross_signing_status(&self) -> Option { - if let Some(machine) = self.client.olm_machine().await { + if let Some(machine) = self.client.olm_machine() { Some(machine.cross_signing_status().await) } else { None @@ -493,12 +493,12 @@ impl Encryption { /// Tracked users are users for which we keep the device list of E2EE /// capable devices up to date. pub async fn tracked_users(&self) -> HashSet { - self.client.olm_machine().await.map(|o| o.tracked_users()).unwrap_or_default() + self.client.olm_machine().map(|o| o.tracked_users()).unwrap_or_default() } /// Get a verification object with the given flow id. pub async fn get_verification(&self, user_id: &UserId, flow_id: &str) -> Option { - let olm = self.client.olm_machine().await?; + let olm = self.client.olm_machine()?; #[allow(clippy::bind_instead_of_map)] olm.get_verification(user_id, flow_id).and_then(|v| match v { matrix_sdk_base::crypto::Verification::SasV1(s) => { @@ -519,7 +519,7 @@ impl Encryption { user_id: &UserId, flow_id: impl AsRef, ) -> Option { - let olm = self.client.olm_machine().await?; + let olm = self.client.olm_machine()?; olm.get_verification_request(user_id, flow_id) .map(|r| VerificationRequest { inner: r, client: self.client.clone() }) @@ -644,7 +644,7 @@ impl Encryption { ) -> Result, CryptoStoreError> { use crate::encryption::identities::UserIdentity; - if let Some(olm) = self.client.olm_machine().await { + if let Some(olm) = self.client.olm_machine() { let identity = olm.get_identity(user_id).await?; Ok(identity.map(|i| match i { @@ -705,7 +705,7 @@ impl Encryption { /// } /// # anyhow::Result::<()>::Ok(()) }); pub async fn bootstrap_cross_signing(&self, auth_data: Option>) -> Result<()> { - let olm = self.client.olm_machine().await.ok_or(Error::AuthenticationRequired)?; + let olm = self.client.olm_machine().ok_or(Error::AuthenticationRequired)?; let (request, signature_request) = olm.bootstrap_cross_signing(false).await?; @@ -782,7 +782,7 @@ impl Encryption { passphrase: &str, predicate: impl FnMut(&matrix_sdk_base::crypto::olm::InboundGroupSession) -> bool, ) -> Result<()> { - let olm = self.client.olm_machine().await.ok_or(Error::AuthenticationRequired)?; + let olm = self.client.olm_machine().ok_or(Error::AuthenticationRequired)?; let keys = olm.export_keys(predicate).await?; let passphrase = zeroize::Zeroizing::new(passphrase.to_owned()); @@ -842,7 +842,7 @@ impl Encryption { path: PathBuf, passphrase: &str, ) -> Result { - let olm = self.client.olm_machine().await.ok_or(RoomKeyImportError::StoreClosed)?; + let olm = self.client.olm_machine().ok_or(RoomKeyImportError::StoreClosed)?; let passphrase = zeroize::Zeroizing::new(passphrase.to_owned()); let decrypt = move || { diff --git a/crates/matrix-sdk/src/http_client.rs b/crates/matrix-sdk/src/http_client.rs index 9161e76a1..1ebb5f502 100644 --- a/crates/matrix-sdk/src/http_client.rs +++ b/crates/matrix-sdk/src/http_client.rs @@ -17,6 +17,7 @@ use std::{any::type_name, convert::TryFrom, fmt::Debug, sync::Arc, time::Duratio use async_trait::async_trait; use bytes::{Bytes, BytesMut}; use http::Response as HttpResponse; +use matrix_sdk_base::once_cell::sync::OnceCell; use matrix_sdk_common::{locks::RwLock, AsyncTraitDeps}; use reqwest::Response; use ruma::api::{ @@ -95,7 +96,7 @@ pub trait HttpSend: AsyncTraitDeps { pub(crate) struct HttpClient { pub(crate) inner: Arc, pub(crate) homeserver: Arc>, - pub(crate) session: Arc>>, + pub(crate) session: OnceCell, pub(crate) request_config: RequestConfig, } @@ -103,10 +104,9 @@ impl HttpClient { pub(crate) fn new( inner: Arc, homeserver: Arc>, - session: Arc>>, request_config: RequestConfig, ) -> Self { - HttpClient { inner, homeserver, session, request_config } + HttpClient { inner, homeserver, session: Default::default(), request_config } } #[tracing::instrument(skip(self, request), fields(request_type = type_name::()))] @@ -130,21 +130,18 @@ impl HttpClient { return Err(HttpError::NotClientRequest); } - let access_token; - let request = if !self.request_config.assert_identity { let send_access_token = if auth_scheme == AuthScheme::None && !config.force_auth { // Small optimization: Don't take the session lock if we know the auth token // isn't going to be used anyways. SendAccessToken::None } else { - match self.session.read().await.as_ref() { + match self.session() { Some(session) => { - access_token = session.access_token.clone(); if config.force_auth { - SendAccessToken::Always(&access_token) + SendAccessToken::Always(&session.access_token) } else { - SendAccessToken::IfRequired(&access_token) + SendAccessToken::IfRequired(&session.access_token) } } None => SendAccessToken::None, @@ -157,18 +154,12 @@ impl HttpClient { &server_versions, )? } else { - let (send_access_token, user_id) = { - let session = self.session.read().await; - let session = session.as_ref().ok_or(HttpError::UserIdRequired)?; - - access_token = session.access_token.clone(); - (SendAccessToken::Always(&access_token), session.user_id.clone()) - }; - request.try_into_http_request_with_user_id::( &self.homeserver.read().await.to_string(), - send_access_token, - &user_id, + SendAccessToken::Always( + &self.session().ok_or(HttpError::UserIdRequired)?.access_token, + ), + &self.session().ok_or(HttpError::UserIdRequired)?.user_id, &server_versions, )? }; @@ -182,6 +173,14 @@ impl HttpClient { Ok(response) } + + pub(crate) fn set_session(&self, session: Session) { + self.session.set(session).expect("A session was already set"); + } + + fn session(&self) -> Option<&Session> { + self.session.get() + } } #[derive(Debug)] diff --git a/crates/matrix-sdk/src/room/common.rs b/crates/matrix-sdk/src/room/common.rs index a78a0f7b2..31c56a3a3 100644 --- a/crates/matrix-sdk/src/room/common.rs +++ b/crates/matrix-sdk/src/room/common.rs @@ -203,7 +203,7 @@ impl Common { }; #[cfg(feature = "e2e-encryption")] - if let Some(machine) = self.client.olm_machine().await { + if let Some(machine) = self.client.olm_machine() { for event in http_response.chunk { let decrypted_event = if let Ok(AnySyncRoomEvent::MessageLike( AnySyncMessageLikeEvent::RoomEncrypted(SyncMessageLikeEvent::Original( @@ -852,9 +852,9 @@ impl Common { tag: TagName, tag_info: TagInfo, ) -> HttpResult { - let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?; + let user_id = self.client.user_id().ok_or(HttpError::AuthenticationRequired)?; let request = - create_tag::v3::Request::new(&user_id, self.inner.room_id(), tag.as_ref(), tag_info); + create_tag::v3::Request::new(user_id, self.inner.room_id(), tag.as_ref(), tag_info); self.client.send(request, None).await } @@ -865,8 +865,8 @@ impl Common { /// # Arguments /// * `tag` - The tag to remove. pub async fn remove_tag(&self, tag: TagName) -> HttpResult { - let user_id = self.client.user_id().await.ok_or(HttpError::AuthenticationRequired)?; - let request = delete_tag::v3::Request::new(&user_id, self.inner.room_id(), tag.as_ref()); + let user_id = self.client.user_id().ok_or(HttpError::AuthenticationRequired)?; + let request = delete_tag::v3::Request::new(user_id, self.inner.room_id(), tag.as_ref()); self.client.send(request, None).await } @@ -879,11 +879,8 @@ impl Common { /// # Arguments /// * `is_direct` - Whether to mark this room as direct. pub async fn set_is_direct(&self, is_direct: bool) -> Result<()> { - let user_id = self - .client - .user_id() - .await - .ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?; + let user_id = + self.client.user_id().ok_or_else(|| Error::from(HttpError::AuthenticationRequired))?; let mut content = self .client @@ -914,7 +911,7 @@ impl Common { content.retain(|_, list| !list.is_empty()); } - let request = set_global_account_data::v3::Request::new(&content, &user_id)?; + let request = set_global_account_data::v3::Request::new(&content, user_id)?; self.client.send(request, None).await?; Ok(()) @@ -928,7 +925,7 @@ impl Common { /// Returns the decrypted event. #[cfg(feature = "e2e-encryption")] pub async fn decrypt_event(&self, event: &OriginalSyncRoomEncryptedEvent) -> Result { - if let Some(machine) = self.client.olm_machine().await { + if let Some(machine) = self.client.olm_machine() { Ok(machine.decrypt_room_event(event, self.inner.room_id()).await?) } else { Err(Error::NoOlmMachine) diff --git a/crates/matrix-sdk/src/room/joined.rs b/crates/matrix-sdk/src/room/joined.rs index b65455119..da739d41a 100644 --- a/crates/matrix-sdk/src/room/joined.rs +++ b/crates/matrix-sdk/src/room/joined.rs @@ -572,7 +572,7 @@ impl Joined { self.preshare_group_session().await?; - let olm = self.client.olm_machine().await.expect("Olm machine wasn't started"); + let olm = self.client.olm_machine().expect("Olm machine wasn't started"); let encrypted_content = olm.encrypt_room_event_raw(self.inner.room_id(), content, event_type).await?;