diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index e2f2a7f86..8bb2f8194 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -17,17 +17,9 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = [] -encryption = ["matrix-sdk-crypto", "store_key"] +encryption = ["matrix-sdk-crypto"] qrcode = ["matrix-sdk-crypto/qrcode"] -# Store Key is a helper feature for state store implementations -store_key = [ - "pbkdf2", - "hmac", - "sha2", - "rand", - "chacha20poly1305", -] # helpers for testing features build upon this testing = [ "http" ] diff --git a/crates/matrix-sdk-base/README.md b/crates/matrix-sdk-base/README.md index e5c1ecf4f..af30e9be3 100644 --- a/crates/matrix-sdk-base/README.md +++ b/crates/matrix-sdk-base/README.md @@ -7,5 +7,4 @@ The following crate feature flags are available: * `encryption`: Enables end-to-end encryption support in the library. * `qrcode`: Enbles QRcode generation and reading code -* `store_key`: Provides extra facilities for StateStores to manage store keys * `testing`: provides facilities and functions for tests, in particular for integration testing store implementations. ATTENTION: do not ever use outside of tests, we do not provide any stability warantees on these, these are merely helpers. If you find you _need_ any function provided here outside of tests, please open a Github Issue and inform us about your use case for us to consider. \ No newline at end of file diff --git a/crates/matrix-sdk-base/src/store/integration_tests.rs b/crates/matrix-sdk-base/src/store/integration_tests.rs index dca117cae..eacd563ed 100644 --- a/crates/matrix-sdk-base/src/store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/store/integration_tests.rs @@ -277,18 +277,18 @@ macro_rules! statestore_integration_tests { assert!(store.get_sync_token().await?.is_some()); assert!(store.get_presence_event(user_id).await?.is_some()); - assert_eq!(store.get_room_infos().await?.len(), 1); - assert_eq!(store.get_stripped_room_infos().await?.len(), 1); + assert_eq!(store.get_room_infos().await?.len(), 1, "Expected to find 1 room info "); + assert_eq!(store.get_stripped_room_infos().await?.len(), 1, "Expected to find 1 stripped room info"); assert!(store.get_account_data_event(GlobalAccountDataEventType::PushRules).await?.is_some()); assert!(store.get_state_event(room_id, StateEventType::RoomName, "").await?.is_some()); - assert_eq!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), 1); + assert_eq!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), 1, "Expected to find 1 room topic"); assert!(store.get_profile(room_id, user_id).await?.is_some()); assert!(store.get_member_event(room_id, user_id).await?.is_some()); - assert_eq!(store.get_user_ids(room_id).await?.len(), 2); - assert_eq!(store.get_invited_user_ids(room_id).await?.len(), 1); - assert_eq!(store.get_joined_user_ids(room_id).await?.len(), 1); - assert_eq!(store.get_users_with_display_name(room_id, "example").await?.len(), 2); + assert_eq!(store.get_user_ids(room_id).await?.len(), 2, "Expected to find 2 members for room"); + assert_eq!(store.get_invited_user_ids(room_id).await?.len(), 1, "Expected to find 1 invited user ids"); + assert_eq!(store.get_joined_user_ids(room_id).await?.len(), 1, "Expected to find 1 joined user ids"); + assert_eq!(store.get_users_with_display_name(room_id, "example").await?.len(), 2, "Expected to find 2 display names for room"); assert!(store .get_room_account_data_event(room_id, RoomAccountDataEventType::Tag) .await? @@ -302,8 +302,7 @@ macro_rules! statestore_integration_tests { .get_event_room_receipt_events(room_id, ReceiptType::Read, first_receipt_event_id()) .await? .len(), - 1 - ); + 1, "Expected to find 1 read receipt"); Ok(()) } @@ -325,7 +324,7 @@ macro_rules! statestore_integration_tests { assert!(store.get_member_event(room_id, user_id).await.unwrap().is_some()); let members = store.get_user_ids(room_id).await.unwrap(); - assert!(!members.is_empty()) + assert!(!members.is_empty(), "We expected to find members for the room") } #[async_test] @@ -354,7 +353,7 @@ macro_rules! statestore_integration_tests { #[async_test] async fn test_receipts_saving() { - let store = get_store().await.unwrap(); + let store = get_store().await.expect("creating store failed"); let room_id = room_id!("!test_receipts_saving:localhost"); @@ -370,7 +369,7 @@ macro_rules! statestore_integration_tests { } } })) - .unwrap(); + .expect("json creation failed"); let second_receipt_event = serde_json::from_value(json!({ second_event_id.clone(): { @@ -381,68 +380,70 @@ macro_rules! statestore_integration_tests { } } })) - .unwrap(); + .expect("json creation failed"); assert!(store .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id()) .await - .unwrap() + .expect("failed to read user room receipt") .is_none()); assert!(store .get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id) .await - .unwrap() + .expect("failed to read user room receipt for 1") .is_empty()); assert!(store .get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id) .await - .unwrap() + .expect("failed to read user room receipt for 2") .is_empty()); let mut changes = StateChanges::default(); changes.add_receipts(room_id, first_receipt_event); - store.save_changes(&changes).await.unwrap(); + store.save_changes(&changes).await.expect("writing changes fauked"); assert!(store .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id()) .await - .unwrap() - .is_some(),); + .expect("failed to read user room receipt after save") + .is_some()); assert_eq!( store .get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id) .await - .unwrap() + .expect("failed to read user room receipt for 1 after save") .len(), - 1 + 1, + "Found a wrong number of receipts for 1 after save" ); assert!(store .get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id) .await - .unwrap() + .expect("failed to read user room receipt for 2 after save") .is_empty()); let mut changes = StateChanges::default(); changes.add_receipts(room_id, second_receipt_event); - store.save_changes(&changes).await.unwrap(); + store.save_changes(&changes).await.expect("Saving works"); assert!(store .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id()) .await - .unwrap() + .expect("Getting user room receipts failed") .is_some()); assert!(store .get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id) .await - .unwrap() + .expect("Getting event room receipt events for first event failed") .is_empty()); assert_eq!( store .get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id) .await - .unwrap() + .expect("Getting event room receipt events for second event failed") .len(), - 1 + 1, + "Found a wrong number of receipts for second event after save" ); } @@ -467,24 +468,24 @@ macro_rules! statestore_integration_tests { }), }; - assert!(store.get_media_content(&request_file).await.unwrap().is_none()); - assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + assert!(store.get_media_content(&request_file).await.unwrap().is_none(), "unexpected media found"); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none(), "media not found"); - store.add_media_content(&request_file, content.clone()).await.unwrap(); - assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + store.add_media_content(&request_file, content.clone()).await.expect("adding media failed"); + assert!(store.get_media_content(&request_file).await.unwrap().is_some(), "media not found though added"); - store.remove_media_content(&request_file).await.unwrap(); - assert!(store.get_media_content(&request_file).await.unwrap().is_none()); + store.remove_media_content(&request_file).await.expect("removing media failed"); + assert!(store.get_media_content(&request_file).await.unwrap().is_none(), "media still there after removing"); - store.add_media_content(&request_file, content.clone()).await.unwrap(); - assert!(store.get_media_content(&request_file).await.unwrap().is_some()); + store.add_media_content(&request_file, content.clone()).await.expect("adding media again failed"); + assert!(store.get_media_content(&request_file).await.unwrap().is_some(), "media not found after adding again"); - store.add_media_content(&request_thumbnail, content.clone()).await.unwrap(); - assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some()); + store.add_media_content(&request_thumbnail, content.clone()).await.expect("adding thumbnail failed"); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some(), "thumbnail not found"); - store.remove_media_content_for_uri(uri).await.unwrap(); - assert!(store.get_media_content(&request_file).await.unwrap().is_none()); - assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none()); + store.remove_media_content_for_uri(uri).await.expect("removing all media for uri failed"); + assert!(store.get_media_content(&request_file).await.unwrap().is_none(), "media wasn't removed"); + assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none(), "thumbnail wasn't removed"); } #[async_test] @@ -526,17 +527,17 @@ macro_rules! statestore_integration_tests { store.remove_room(room_id).await?; - assert_eq!(store.get_room_infos().await?.len(), 0); + assert!(store.get_room_infos().await?.is_empty(), "room is still there"); assert_eq!(store.get_stripped_room_infos().await?.len(), 1); assert!(store.get_state_event(room_id, StateEventType::RoomName, "").await?.is_none()); - assert_eq!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), 0); + assert!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.is_empty(), "still state events found"); assert!(store.get_profile(room_id, user_id).await?.is_none()); assert!(store.get_member_event(room_id, user_id).await?.is_none()); - assert_eq!(store.get_user_ids(room_id).await?.len(), 0); - assert_eq!(store.get_invited_user_ids(room_id).await?.len(), 0); - assert_eq!(store.get_joined_user_ids(room_id).await?.len(), 0); - assert_eq!(store.get_users_with_display_name(room_id, "example").await?.len(), 0); + assert!(store.get_user_ids(room_id).await?.is_empty(), "still user ids found"); + assert!(store.get_invited_user_ids(room_id).await?.is_empty(), "still invited user ids found"); + assert!(store.get_joined_user_ids(room_id).await?.is_empty(), "still joined users found"); + assert!(store.get_users_with_display_name(room_id, "example").await?.is_empty(), "still display names found"); assert!(store .get_room_account_data_event(room_id, RoomAccountDataEventType::Tag) .await? @@ -545,18 +546,18 @@ macro_rules! statestore_integration_tests { .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id) .await? .is_none()); - assert_eq!( + assert!( store .get_event_room_receipt_events(room_id, ReceiptType::Read, first_receipt_event_id()) .await? - .len(), - 0 + .is_empty(), + "still event recepts in the store" ); store.remove_room(stripped_room_id).await?; - assert_eq!(store.get_room_infos().await?.len(), 0); - assert_eq!(store.get_stripped_room_infos().await?.len(), 0); + assert!(store.get_room_infos().await?.is_empty(), "still room info found"); + assert!(store.get_stripped_room_infos().await?.is_empty(), "still stripped room info found"); Ok(()) } @@ -567,7 +568,7 @@ macro_rules! statestore_integration_tests { let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost"); // Before the first sync the timeline should be empty - assert!(store.room_timeline(room_id).await.unwrap().is_none()); + assert!(store.room_timeline(room_id).await.expect("failed to read timeline").is_none(), "TL wasn't empty"); // Add sync response let sync = SyncResponse::try_from_http_response( @@ -700,10 +701,17 @@ macro_rules! statestore_integration_tests { let timeline = timeline_iter.collect::>>().await; - assert!(timeline - .into_iter() - .zip(stored_events.iter()) - .all(|(a, b)| a.unwrap().event_id() == b.event_id())); + let expected: Vec> = stored_events.iter().map(|a| a.event_id().expect("event id doesn't exist")).collect(); + let found: Vec> = timeline.iter().map(|a| a.as_ref().expect("object missing").event_id().clone().expect("event id missing")).collect(); + + for (idx, (a, b)) in timeline + .into_iter() + .zip(stored_events.iter()) + .enumerate() + { + assert_eq!(a.expect("not a value").event_id(), b.event_id(), "pos {} not equal - expected: {:#?}, but found {:#?}", idx, expected, found); + + } } } diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index 50dd0e36f..8869fff83 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -50,9 +50,6 @@ use ruma::{ EventId, MxcUri, RoomId, UserId, }; -#[cfg(feature = "store_key")] -pub mod store_key; - /// BoxStream of owned Types pub type BoxStream = Pin + Send>>; diff --git a/crates/matrix-sdk-base/src/store/store_key.rs b/crates/matrix-sdk-base/src/store/store_key.rs deleted file mode 100644 index 49729881c..000000000 --- a/crates/matrix-sdk-base/src/store/store_key.rs +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Facilities for StateStore implementations to reuse to manage encrypted -//! state keys - -use std::convert::TryFrom; - -use chacha20poly1305::{ - aead::{Aead, Error as EncryptionError, NewAead}, - ChaCha20Poly1305, Key, Nonce, XChaCha20Poly1305, XNonce, -}; -use hmac::Hmac; -use pbkdf2::pbkdf2; -use rand::{thread_rng, Error as RngError, Fill}; -use serde::{Deserialize, Serialize}; -use sha2::Sha256; -use zeroize::{Zeroize, Zeroizing}; - -use crate::StoreError; - -const VERSION: u8 = 1; -const KEY_SIZE: usize = 32; -const NONCE_SIZE: usize = 12; -const XNONCE_SIZE: usize = 24; -const KDF_SALT_SIZE: usize = 32; -#[cfg(not(test))] -const KDF_ROUNDS: u32 = 200_000; -#[cfg(test)] -const KDF_ROUNDS: u32 = 1000; - -/// Local State Error for Store Keys -/// -/// provides facilities to directly convert into a `StoreError` thus the most -/// common usage is to `.map_err(StoreError::into)` it. -#[derive(Debug, thiserror::Error)] -pub enum Error { - /// A problem de/serializing the JSON - #[error(transparent)] - Serialization(#[from] serde_json::Error), - /// A problem with en/decrypting - #[error("Error encrypting or decrypting an event {0}")] - Encryption(String), - /// Error generating enough random data for a cryptographic operation - #[error("Error generating enough random data for a cryptographic operation")] - Random(#[from] RngError), -} - -#[allow(clippy::from_over_into)] -impl Into for Error { - fn into(self) -> StoreError { - match self { - Error::Serialization(e) => StoreError::Json(e), - Error::Encryption(e) => StoreError::Encryption(e), - Error::Random(_) => StoreError::Encryption(self.to_string()), - } - } -} - -impl From for Error { - fn from(e: EncryptionError) -> Self { - Error::Encryption(e.to_string()) - } -} - -/// Holding the data of the encrypted event -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct EncryptedEvent { - version: u8, - ciphertext: Vec, - nonce: Vec, -} - -/// Version specific info for the key derivation method that is used. -#[derive(Debug, Serialize, Deserialize, PartialEq)] -enum KdfInfo { - Pbkdf2ToChaCha20Poly1305 { - /// The number of PBKDF rounds that were used when deriving the store - /// key. - rounds: u32, - /// The salt that was used when the passphrase was expanded into a store - /// key. - kdf_salt: Vec, - }, -} - -/// Version specific info for encryption method that is used to encrypt our -/// store key. -#[derive(Debug, Serialize, Deserialize, PartialEq)] -enum CipherTextInfo { - ChaCha20Poly1305 { - /// The nonce that was used to encrypt the ciphertext. - nonce: Vec, - /// The encrypted store key. - ciphertext: Vec, - }, -} - -/// An encrypted version of our store key, this can be safely stored in a -/// database. -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct EncryptedStoreKey { - /// Info about the key derivation method that was used to expand the - /// passphrase into an encryption key. - kdf_info: KdfInfo, - /// The ciphertext with it's accompanying additional data that is needed to - /// decrypt the store key. - ciphertext_info: CipherTextInfo, -} - -/// A store key that can be used to encrypt entries in the store. -#[derive(Debug, Zeroize, PartialEq)] -pub struct StoreKey { - inner: Vec, -} - -impl TryFrom> for StoreKey { - type Error = (); - - fn try_from(value: Vec) -> Result { - if value.len() != KEY_SIZE { - Err(()) - } else { - Ok(Self { inner: value }) - } - } -} - -impl StoreKey { - /// Generate a new random store key. - pub fn new() -> Result { - let mut key = vec![0u8; KEY_SIZE]; - let mut rng = thread_rng(); - key.try_fill(&mut rng)?; - - Ok(Self { inner: key }) - } - - /// Expand the given passphrase into a KEY_SIZE long key. - fn expand_key(passphrase: &str, salt: &[u8], rounds: u32) -> Zeroizing> { - let mut key = Zeroizing::from(vec![0u8; KEY_SIZE]); - pbkdf2::>(passphrase.as_bytes(), salt, rounds, &mut *key); - key - } - - /// Get the store key. - fn key(&self) -> &Key { - Key::from_slice(&self.inner) - } - - /// Encrypt and export our store key using the given passphrase. - /// - /// # Arguments - /// - /// * `passphrase` - The passphrase that should be used to encrypt the - /// store key. - pub fn export(&self, passphrase: &str) -> Result { - let mut rng = thread_rng(); - - let mut salt = vec![0u8; KDF_SALT_SIZE]; - salt.try_fill(&mut rng)?; - - let key = StoreKey::expand_key(passphrase, &salt, KDF_ROUNDS); - let key = Key::from_slice(key.as_ref()); - let cipher = ChaCha20Poly1305::new(key); - - let mut nonce = vec![0u8; NONCE_SIZE]; - nonce.try_fill(&mut rng)?; - - let ciphertext = - cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?; - - Ok(EncryptedStoreKey { - kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt }, - ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext }, - }) - } - - fn get_nonce() -> Result, RngError> { - let mut nonce = vec![0u8; XNONCE_SIZE]; - let mut rng = thread_rng(); - - nonce.try_fill(&mut rng)?; - - Ok(nonce) - } - - /// Encrypt the given Event after serializing it with serde_json - pub fn encrypt(&self, event: &impl Serialize) -> Result { - let event = serde_json::to_vec(event)?; - - let nonce = StoreKey::get_nonce()?; - let cipher = XChaCha20Poly1305::new(self.key()); - let xnonce = XNonce::from_slice(&nonce); - - let ciphertext = cipher.encrypt(xnonce, event.as_ref())?; - - Ok(EncryptedEvent { version: VERSION, ciphertext, nonce }) - } - - /// Decrypt the given encrypted event back into the inner - pub fn decrypt Deserialize<'b>>(&self, event: EncryptedEvent) -> Result { - if event.version != VERSION { - return Err(Error::Encryption( - "Error decrypting: Unknown ciphertext version".to_owned(), - )); - } - - let cipher = XChaCha20Poly1305::new(self.key()); - let nonce = XNonce::from_slice(&event.nonce); - let plaintext = cipher.decrypt(nonce, event.ciphertext.as_ref())?; - - Ok(serde_json::from_slice(&plaintext)?) - } - - /// Restore a store key from an encrypted export. - /// - /// # Arguments - /// - /// * `passphrase` - The passphrase that should be used to encrypt the - /// store key. - /// - /// * `encrypted` - The exported and encrypted version of the store key. - pub fn import(passphrase: &str, encrypted: EncryptedStoreKey) -> Result { - let key = match encrypted.kdf_info { - KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds, kdf_salt } => { - Self::expand_key(passphrase, &kdf_salt, rounds) - } - }; - - let key = Key::from_slice(key.as_ref()); - - let decrypted = match encrypted.ciphertext_info { - CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext } => { - let cipher = ChaCha20Poly1305::new(key); - let nonce = Nonce::from_slice(&nonce); - cipher.decrypt(nonce, ciphertext.as_ref())? - } - }; - - Ok(Self { inner: decrypted }) - } -} - -#[cfg(test)] -mod tests { - use serde_json::{json, Value}; - - use super::StoreKey; - - #[test] - fn generating() { - StoreKey::new().unwrap(); - } - - #[test] - fn encrypting() { - let passphrase = "it's a secret to everybody"; - let store_key = StoreKey::new().unwrap(); - - let encrypted = store_key.export(passphrase).unwrap(); - let decrypted = StoreKey::import(passphrase, encrypted).unwrap(); - - assert_eq!(store_key, decrypted); - } - - #[test] - fn encrypting_events() { - let event = json!({ - "content": { - "body": "Bee Gees - Stayin' Alive", - "info": { - "duration": 2140786, - "mimetype": "audio/mpeg", - "size": 1563685 - }, - "msgtype": "m.audio", - "url": "mxc://example.org/ffed755USFFxlgbQYZGtryd" - }, - }); - - let store_key = StoreKey::new().unwrap(); - - let encrypted = store_key.encrypt(&event).unwrap(); - let decrypted: Value = store_key.decrypt(encrypted).unwrap(); - assert_eq!(event, decrypted); - } -} diff --git a/crates/matrix-sdk-indexeddb/Cargo.toml b/crates/matrix-sdk-indexeddb/Cargo.toml index 91d54e51c..0ddf6117f 100644 --- a/crates/matrix-sdk-indexeddb/Cargo.toml +++ b/crates/matrix-sdk-indexeddb/Cargo.toml @@ -11,21 +11,22 @@ encryption = ["matrix-sdk-base/encryption", "matrix-sdk-crypto"] default-target = "wasm32-unknown-unknown" [dependencies] -matrix-sdk-base = { path = "../matrix-sdk-base", features = ["store_key"] } -matrix-sdk-crypto = { path = "../matrix-sdk-crypto", optional = true } -matrix-sdk-common = { path = "../matrix-sdk-common" } -matrix-sdk-store-encryption = { path = "../matrix-sdk-store-encryption" } -thiserror = "1.0.25" anyhow = "1" - +base64 = { version = "0.13.0" } +dashmap = "5.1.0" futures-util = { version = "0.3.15", default-features = false } indexed_db_futures = { version = "0.2.0" } -wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"] } -web-sys = { version = "0.3.35", features = ["IdbKeyRange"] } +matrix-sdk-base = { path = "../matrix-sdk-base" } +matrix-sdk-common = { path = "../matrix-sdk-common" } +matrix-sdk-store-encryption = { path = "../matrix-sdk-store-encryption" } serde = { version = "1.0.126" } serde_json = "1.0.64" -dashmap = "5.1.0" +thiserror = "1.0.25" tracing = "0.1.26" +wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"] } +web-sys = { version = "0.3.35", features = ["IdbKeyRange"] } + +matrix-sdk-crypto = { path = "../matrix-sdk-crypto", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] @@ -36,4 +37,5 @@ getrandom = { version = "0.2", features = ["js"] } matrix-sdk-base = { path = "../matrix-sdk-base", features = ["testing"] } matrix-sdk-crypto = { path = "../matrix-sdk-crypto", features = ["testing"] } matrix-sdk-test = { path = "../matrix-sdk-test" } +uuid = "0.8" wasm-bindgen-test = "0.3.24" diff --git a/crates/matrix-sdk-indexeddb/src/safe_encode.rs b/crates/matrix-sdk-indexeddb/src/safe_encode.rs index 1a6999a37..2d20b5c9e 100644 --- a/crates/matrix-sdk-indexeddb/src/safe_encode.rs +++ b/crates/matrix-sdk-indexeddb/src/safe_encode.rs @@ -1,10 +1,12 @@ #![allow(dead_code)] +use base64::{encode_config as base64_encode, STANDARD_NO_PAD}; use matrix_sdk_base::ruma::events::{ GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, }; use matrix_sdk_common::ruma::{ receipt::ReceiptType, DeviceId, EventId, MxcUri, RoomId, TransactionId, UserId, }; +use matrix_sdk_store_encryption::StoreCipher; use wasm_bindgen::JsValue; use web_sys::IdbKeyRange; @@ -36,7 +38,46 @@ pub trait SafeEncode { /// encode self into a JsValue, internally using `as_encoded_string` /// to escape the value of self. fn encode(&self) -> JsValue { - JsValue::from(self.as_encoded_string()) + self.as_encoded_string().into() + } + + /// encode self into a JsValue, internally using `as_secure_string` + /// to escape the value of self, + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> JsValue { + self.as_secure_string(table_name, store_cipher).into() + } + + /// encode self securely for the given tablename with the given + /// `store_cipher` hash_key, returns the value as a base64 encoded + /// string without any padding. + fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { + base64_encode( + store_cipher.hash_key(table_name, self.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ) + } + + /// encode self into a JsValue, internally using `as_encoded_string` + /// to escape the value of self, and append the given counter + fn encode_with_counter(&self, i: usize) -> JsValue { + format!("{}{}{:0000000x}", self.as_encoded_string(), KEY_SEPARATOR, i).into() + } + + /// encode self into a JsValue, internally using `as_secure_string` + /// to escape the value of self, and append the given counter + fn encode_with_counter_secure( + &self, + table_name: &str, + store_cipher: &StoreCipher, + i: usize, + ) -> JsValue { + format!( + "{}{}{:0000000x}", + self.as_secure_string(table_name, store_cipher), + KEY_SEPARATOR, + i + ) + .into() } /// Encode self into a IdbKeyRange for searching all keys that are @@ -50,6 +91,19 @@ pub trait SafeEncode { ) .map_err(|e| e.as_string().unwrap_or_else(|| "Creating key range failed".to_owned())) } + + fn encode_to_range_secure( + &self, + table_name: &str, + store_cipher: &StoreCipher, + ) -> Result { + let key = self.as_secure_string(table_name, store_cipher); + IdbKeyRange::bound( + &JsValue::from([&key, KEY_SEPARATOR].concat()), + &JsValue::from([&key, RANGE_END].concat()), + ) + .map_err(|e| e.as_string().unwrap_or_else(|| "Creating key range failed".to_owned())) + } } /// Implement SafeEncode for tuple of two elements, separating the escaped @@ -62,6 +116,21 @@ where fn as_encoded_string(&self) -> String { [&self.0.as_encoded_string(), KEY_SEPARATOR, &self.1.as_encoded_string()].concat() } + + fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { + [ + &base64_encode( + store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + KEY_SEPARATOR, + &base64_encode( + store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + ] + .concat() + } } /// Implement SafeEncode for tuple of three elements, separating the escaped @@ -82,6 +151,26 @@ where ] .concat() } + + fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { + [ + &base64_encode( + store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + KEY_SEPARATOR, + &base64_encode( + store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + KEY_SEPARATOR, + &base64_encode( + store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + ] + .concat() + } } /// Implement SafeEncode for tuple of four elements, separating the escaped @@ -105,6 +194,31 @@ where ] .concat() } + + fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { + [ + &base64_encode( + store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + KEY_SEPARATOR, + &base64_encode( + store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + KEY_SEPARATOR, + &base64_encode( + store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + KEY_SEPARATOR, + &base64_encode( + store_cipher.hash_key(table_name, self.3.as_encoded_string().as_bytes()), + STANDARD_NO_PAD, + ), + ] + .concat() + } } impl SafeEncode for String { @@ -195,4 +309,8 @@ impl SafeEncode for usize { fn as_encoded_string(&self) -> String { self.to_string() } + + fn as_secure_string(&self, _table_name: &str, _store_cipher: &StoreCipher) -> String { + self.to_string() + } } diff --git a/crates/matrix-sdk-indexeddb/src/state_store.rs b/crates/matrix-sdk-indexeddb/src/state_store.rs index 9c8dd3553..7461c0805 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store.rs @@ -20,10 +20,7 @@ use indexed_db_futures::prelude::*; use matrix_sdk_base::{ deserialized_responses::{MemberEvent, SyncRoomEvent}, media::{MediaRequest, UniqueKey}, - store::{ - store_key::{self, EncryptedEvent, StoreKey}, - BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError, - }, + store::{BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError}, RoomInfo, }; use matrix_sdk_common::{ @@ -46,24 +43,23 @@ use matrix_sdk_common::{ EventId, MxcUri, RoomId, RoomVersionId, UserId, }, }; +use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher}; use serde::{Deserialize, Serialize}; use tracing::{info, warn}; use wasm_bindgen::JsValue; +use web_sys::IdbKeyRange; use crate::safe_encode::SafeEncode; -#[derive(Debug, Serialize, Deserialize)] -pub enum DatabaseType { - Unencrypted, - Encrypted(store_key::EncryptedStoreKey), -} +#[derive(Clone, Serialize, Deserialize)] +struct StoreKeyWrapper(Vec); #[derive(Debug, thiserror::Error)] pub enum SerializationError { #[error(transparent)] Json(#[from] serde_json::Error), #[error(transparent)] - Encryption(#[from] store_key::Error), + Encryption(#[from] EncryptionError), #[error("DomException {name} ({code}): {message}")] DomException { name: String, message: String, code: u16 }, #[error(transparent)] @@ -86,9 +82,17 @@ impl From for StoreError { SerializationError::Json(e) => StoreError::Json(e), SerializationError::StoreError(e) => e, SerializationError::Encryption(e) => match e { - store_key::Error::Random(e) => StoreError::Encryption(e.to_string()), - store_key::Error::Serialization(e) => StoreError::Json(e), - store_key::Error::Encryption(e) => StoreError::Encryption(e), + EncryptionError::Random(e) => StoreError::Encryption(e.to_string()), + EncryptionError::Serialization(e) => StoreError::Json(e), + EncryptionError::Encryption(e) => StoreError::Encryption(e.to_string()), + EncryptionError::Version(found, expected) => StoreError::Encryption(format!( + "Bad Database Encryption Version: expected {} found {}", + expected, found + )), + EncryptionError::Length(found, expected) => StoreError::Encryption(format!( + "The database key an invalid length: expected {} found {}", + expected, found + )), }, _ => StoreError::Backend(anyhow!(e)), } @@ -138,7 +142,7 @@ mod KEYS { pub struct IndexeddbStore { name: String, pub(crate) inner: IdbDatabase, - store_key: Option, + store_cipher: Option, } impl std::fmt::Debug for IndexeddbStore { @@ -150,7 +154,7 @@ impl std::fmt::Debug for IndexeddbStore { type Result = std::result::Result; impl IndexeddbStore { - async fn open_helper(name: String, store_key: Option) -> Result { + async fn open_helper(name: String, store_cipher: Option) -> Result { // Open my_db v1 let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?; db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { @@ -193,7 +197,7 @@ impl IndexeddbStore { let db: IdbDatabase = db_req.into_future().await?; - Ok(Self { name, inner: db, store_key }) + Ok(Self { name, inner: db, store_cipher }) } #[allow(dead_code)] @@ -205,7 +209,7 @@ impl IndexeddbStore { Ok(Self::inner_open_with_passphrase(name, passphrase).await?) } - async fn inner_open_with_passphrase(name: String, passphrase: &str) -> Result { + pub(crate) async fn inner_open_with_passphrase(name: String, passphrase: &str) -> Result { let name = format!("{:0}::matrix-sdk-state", name); let mut db_req: OpenDbRequest = IdbDatabase::open_u32(&name, 1)?; @@ -225,33 +229,25 @@ impl IndexeddbStore { db.transaction_on_one_with_mode("matrix-sdk-state", IdbTransactionMode::Readwrite)?; let ob = tx.object_store("matrix-sdk-state")?; - let store_key: Option = ob + let cipher = if let Some(StoreKeyWrapper(inner)) = ob .get(&JsValue::from_str(KEYS::STORE_KEY))? .await? - .map(|k| k.into_serde()) - .transpose()?; - - let store_key = if let Some(key) = store_key { - if let DatabaseType::Encrypted(k) = key { - StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)? - } else { - return Err(StoreError::UnencryptedStore.into()); - } + .map(|v| v.into_serde()) + .transpose()? + { + StoreCipher::import(passphrase, &inner)? } else { - let key = StoreKey::new().map_err::(|e| e.into())?; - let encrypted_key = DatabaseType::Encrypted( - key.export(passphrase).map_err::(|e| e.into())?, - ); + let cipher = StoreCipher::new()?; ob.put_key_val( &JsValue::from_str(KEYS::STORE_KEY), - &JsValue::from_serde(&encrypted_key)?, + &JsValue::from_serde(&StoreKeyWrapper(cipher.export(passphrase)?))?, )?; - key + cipher }; tx.await.into_result()?; - IndexeddbStore::open_helper(name, Some(store_key)).await + IndexeddbStore::open_helper(name, Some(cipher)).await } pub async fn open_with_name(name: String) -> StoreResult { @@ -262,8 +258,8 @@ impl IndexeddbStore { &self, event: &impl Serialize, ) -> std::result::Result { - Ok(match self.store_key { - Some(ref key) => JsValue::from_serde(&key.encrypt(event)?)?, + Ok(match self.store_cipher { + Some(ref cipher) => JsValue::from_serde(&cipher.encrypt_value_typed(event)?)?, None => JsValue::from_serde(event)?, }) } @@ -272,15 +268,47 @@ impl IndexeddbStore { &self, event: JsValue, ) -> std::result::Result { - match self.store_key { - Some(ref key) => { - let encrypted: EncryptedEvent = event.into_serde()?; - Ok(key.decrypt(encrypted)?) - } + match self.store_cipher { + Some(ref cipher) => Ok(cipher.decrypt_value_typed(event.into_serde()?)?), None => Ok(event.into_serde()?), } } + fn encode_key(&self, table_name: &str, key: T) -> JsValue + where + T: SafeEncode, + { + match self.store_cipher { + Some(ref cipher) => key.encode_secure(table_name, cipher), + None => key.encode(), + } + } + + fn encode_to_range( + &self, + table_name: &str, + key: T, + ) -> Result + where + T: SafeEncode, + { + match self.store_cipher { + Some(ref cipher) => key.encode_to_range_secure(table_name, cipher), + None => key.encode_to_range(), + } + .map_err(|e| SerializationError::StoreError(StoreError::Backend(anyhow!(e)))) + } + + fn encode_key_with_counter(&self, table_name: &str, key: &T, i: usize) -> JsValue + where + T: SafeEncode, + { + match self.store_cipher { + Some(ref cipher) => key.encode_with_counter_secure(table_name, cipher, i), + None => key.encode_with_counter(i), + } + } + pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { let tx = self .inner @@ -288,7 +316,10 @@ impl IndexeddbStore { let obj = tx.object_store(KEYS::SESSION)?; - obj.put_key_val(&(KEYS::FILTER, filter_name).encode(), &JsValue::from_str(filter_id))?; + obj.put_key_val( + &self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)), + &JsValue::from_str(filter_id), + )?; tx.await.into_result()?; @@ -300,7 +331,7 @@ impl IndexeddbStore { .inner .transaction_on_one_with_mode(KEYS::SESSION, IdbTransactionMode::Readonly)? .object_store(KEYS::SESSION)? - .get(&(KEYS::FILTER, filter_name).encode())? + .get(&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)))? .await? .and_then(|f| f.as_string())) } @@ -373,7 +404,7 @@ impl IndexeddbStore { let store = tx.object_store(KEYS::DISPLAY_NAMES)?; for (room_id, ambiguity_maps) in &changes.ambiguity_maps { for (display_name, map) in ambiguity_maps { - let key = (room_id, display_name).encode(); + let key = self.encode_key(KEYS::DISPLAY_NAMES, (room_id, display_name)); store.put_key_val(&key, &self.serialize_event(&map)?)?; } @@ -383,7 +414,10 @@ impl IndexeddbStore { if !changes.account_data.is_empty() { let store = tx.object_store(KEYS::ACCOUNT_DATA)?; for (event_type, event) in &changes.account_data { - store.put_key_val(&event_type.encode(), &self.serialize_event(&event)?)?; + store.put_key_val( + &self.encode_key(KEYS::ACCOUNT_DATA, event_type), + &self.serialize_event(&event)?, + )?; } } @@ -391,7 +425,7 @@ impl IndexeddbStore { let store = tx.object_store(KEYS::ROOM_ACCOUNT_DATA)?; for (room, events) in &changes.room_account_data { for (event_type, event) in events { - let key = (room, event_type).encode(); + let key = self.encode_key(KEYS::ROOM_ACCOUNT_DATA, (room, event_type)); store.put_key_val(&key, &self.serialize_event(&event)?)?; } } @@ -402,7 +436,7 @@ impl IndexeddbStore { for (room, event_types) in &changes.state { for (event_type, events) in event_types { for (state_key, event) in events { - let key = (room, event_type, state_key).encode(); + let key = self.encode_key(KEYS::ROOM_STATE, (room, event_type, state_key)); store.put_key_val(&key, &self.serialize_event(&event)?)?; } } @@ -412,21 +446,30 @@ impl IndexeddbStore { if !changes.room_infos.is_empty() { let store = tx.object_store(KEYS::ROOM_INFOS)?; for (room_id, room_info) in &changes.room_infos { - store.put_key_val(&room_id.encode(), &self.serialize_event(&room_info)?)?; + store.put_key_val( + &self.encode_key(KEYS::ROOM_INFOS, room_id), + &self.serialize_event(&room_info)?, + )?; } } if !changes.presence.is_empty() { let store = tx.object_store(KEYS::PRESENCE)?; for (sender, event) in &changes.presence { - store.put_key_val(&sender.encode(), &self.serialize_event(&event)?)?; + store.put_key_val( + &self.encode_key(KEYS::PRESENCE, sender), + &self.serialize_event(&event)?, + )?; } } if !changes.stripped_room_infos.is_empty() { let store = tx.object_store(KEYS::STRIPPED_ROOM_INFOS)?; for (room_id, info) in &changes.stripped_room_infos { - store.put_key_val(&room_id.encode(), &self.serialize_event(&info)?)?; + store.put_key_val( + &self.encode_key(KEYS::STRIPPED_ROOM_INFOS, room_id), + &self.serialize_event(&info)?, + )?; } } @@ -434,7 +477,7 @@ impl IndexeddbStore { let store = tx.object_store(KEYS::STRIPPED_MEMBERS)?; for (room, events) in &changes.stripped_members { for event in events.values() { - let key = (room, &event.state_key).encode(); + let key = self.encode_key(KEYS::STRIPPED_MEMBERS, (room, &event.state_key)); store.put_key_val(&key, &self.serialize_event(&event)?)?; } } @@ -445,7 +488,8 @@ impl IndexeddbStore { for (room, event_types) in &changes.stripped_state { for (event_type, events) in event_types { for (state_key, event) in events { - let key = (room, event_type, state_key).encode(); + let key = self + .encode_key(KEYS::STRIPPED_ROOM_STATE, (room, event_type, state_key)); store.put_key_val(&key, &self.serialize_event(&event)?)?; } } @@ -462,27 +506,39 @@ impl IndexeddbStore { let members = tx.object_store(KEYS::MEMBERS)?; for event in events.values() { - let key = (room, &event.state_key).encode(); + let key = (room, &event.state_key); match event.content.membership { MembershipState::Join => { - joined.put_key_val_owned(&key, &event.state_key.encode())?; - invited.delete(&key)?; + joined.put_key_val_owned( + &self.encode_key(KEYS::JOINED_USER_IDS, key), + &self.serialize_event(&event.state_key)?, + )?; + invited.delete(&self.encode_key(KEYS::INVITED_USER_IDS, key))?; } MembershipState::Invite => { - invited.put_key_val_owned(&key, &event.state_key.encode())?; - joined.delete(&key)?; + invited.put_key_val_owned( + &self.encode_key(KEYS::INVITED_USER_IDS, key), + &self.serialize_event(&event.state_key)?, + )?; + joined.delete(&self.encode_key(KEYS::JOINED_USER_IDS, key))?; } _ => { - joined.delete(&key)?; - invited.delete(&key)?; + joined.delete(&self.encode_key(KEYS::JOINED_USER_IDS, key))?; + invited.delete(&self.encode_key(KEYS::INVITED_USER_IDS, key))?; } } - members.put_key_val_owned(&key, &self.serialize_event(&event)?)?; + members.put_key_val_owned( + &self.encode_key(KEYS::MEMBERS, key), + &self.serialize_event(&event)?, + )?; if let Some(profile) = profile_changes.and_then(|p| p.get(&event.state_key)) { - profiles_store.put_key_val_owned(&key, &self.serialize_event(&profile)?)?; + profiles_store.put_key_val_owned( + &self.encode_key(KEYS::PROFILES, key), + &self.serialize_event(&profile)?, + )?; } } } @@ -496,15 +552,20 @@ impl IndexeddbStore { for (event_id, receipts) in &content.0 { for (receipt_type, receipts) in receipts { for (user_id, receipt) in receipts { - let key = (room, receipt_type, user_id).encode(); + let key = self.encode_key( + KEYS::ROOM_USER_RECEIPTS, + (room, receipt_type, user_id), + ); if let Some((old_event, _)) = room_user_receipts.get(&key)?.await?.and_then(|f| { self.deserialize_event::<(Box, Receipt)>(f).ok() }) { - room_event_receipts - .delete(&(room, receipt_type, &old_event, user_id).encode())?; + room_event_receipts.delete(&self.encode_key( + KEYS::ROOM_EVENT_RECEIPTS, + (room, receipt_type, &old_event, user_id), + ))?; } room_user_receipts @@ -512,8 +573,11 @@ impl IndexeddbStore { // Add the receipt to the room event receipts room_event_receipts.put_key_val( - &(room, receipt_type, event_id, user_id).encode(), - &self.serialize_event(&receipt)?, + &self.encode_key( + KEYS::ROOM_EVENT_RECEIPTS, + (room, receipt_type, event_id, user_id), + ), + &self.serialize_event(&(user_id, receipt))?, )?; } } @@ -533,18 +597,19 @@ impl IndexeddbStore { } else { info!("Save new timeline batch from messages response for {}", room_id); } - let room_key = room_id.encode(); - let metadata: Option = if timeline.limited { info!( "Delete stored timeline for {} because the sync response was limited", room_id ); - let range = room_id.encode_to_range().map_err(StoreError::Codec)?; - let stores = - &[&timeline_store, &timeline_metadata_store, &event_id_to_position_store]; - for store in stores { + let stores = &[ + (KEYS::ROOM_TIMELINE, &timeline_store), + (KEYS::ROOM_TIMELINE_METADATA, &timeline_metadata_store), + (KEYS::ROOM_EVENT_ID_TO_POSITION, &event_id_to_position_store), + ]; + for (table_name, store) in stores { + let range = self.encode_to_range(table_name, room_id)?; for key in store.get_all_keys_with_key(&range)?.await?.iter() { store.delete(&key)?; } @@ -553,7 +618,7 @@ impl IndexeddbStore { None } else { let metadata: Option = timeline_metadata_store - .get(&room_key)? + .get(&self.encode_key(KEYS::ROOM_TIMELINE_METADATA, room_id))? .await? .map(|v| v.into_serde()) .transpose()?; @@ -570,7 +635,10 @@ impl IndexeddbStore { let mut delete_timeline = false; for event in &timeline.events { if let Some(event_id) = event.event_id() { - let event_key = (room_id, &event_id).encode(); + let event_key = self.encode_key( + KEYS::ROOM_EVENT_ID_TO_POSITION, + (room_id, event_id), + ); if event_id_to_position_store .count_with_key_owned(event_key)? .await? @@ -588,13 +656,13 @@ impl IndexeddbStore { room_id ); - let range = room_id.encode_to_range().map_err(StoreError::Codec)?; let stores = &[ - &timeline_store, - &timeline_metadata_store, - &event_id_to_position_store, + (KEYS::ROOM_TIMELINE, &timeline_store), + (KEYS::ROOM_TIMELINE_METADATA, &timeline_metadata_store), + (KEYS::ROOM_EVENT_ID_TO_POSITION, &event_id_to_position_store), ]; - for store in stores { + for (table_name, store) in stores { + let range = self.encode_to_range(table_name, room_id)?; for key in store.get_all_keys_with_key(&range)?.await?.iter() { store.delete(&key)?; } @@ -626,7 +694,7 @@ impl IndexeddbStore { if timeline.sync { let room_version = room_infos - .get(&room_key)? + .get(&self.encode_key(KEYS::ROOM_INFOS, room_id))? .await? .map(|r| self.deserialize_event::(r)) .transpose()? @@ -646,7 +714,10 @@ impl IndexeddbStore { ), )) = event.event.deserialize() { - let redacts_key = (room_id, &redaction.redacts).encode(); + let redacts_key = self.encode_key( + KEYS::ROOM_EVENT_ID_TO_POSITION, + (room_id, &redaction.redacts), + ); if let Some(position_key) = event_id_to_position_store.get_owned(redacts_key)?.await? { @@ -673,22 +744,32 @@ impl IndexeddbStore { } metadata.start_position -= 1; - let key = (room_id, &metadata.start_position).encode(); + let key = self.encode_key_with_counter( + KEYS::ROOM_TIMELINE, + room_id, + metadata.start_position, + ); // Only add event with id to the position map if let Some(event_id) = event.event_id() { - let event_key = (room_id, &event_id).encode(); + let event_key = self + .encode_key(KEYS::ROOM_EVENT_ID_TO_POSITION, (room_id, event_id)); event_id_to_position_store.put_key_val(&event_key, &key)?; } - timeline_store.put_key_val_owned(key, &self.serialize_event(&event)?)?; + timeline_store.put_key_val_owned(&key, &self.serialize_event(&event)?)?; } } else { for event in &timeline.events { metadata.end_position += 1; - let key = (room_id, &metadata.end_position).encode(); + let key = self.encode_key_with_counter( + KEYS::ROOM_TIMELINE, + room_id, + metadata.end_position, + ); // Only add event with id to the position map if let Some(event_id) = event.event_id() { - let event_key = (room_id, &event_id).encode(); + let event_key = self + .encode_key(KEYS::ROOM_EVENT_ID_TO_POSITION, (room_id, event_id)); event_id_to_position_store.put_key_val(&event_key, &key)?; } @@ -696,8 +777,10 @@ impl IndexeddbStore { } } - timeline_metadata_store - .put_key_val_owned(room_key, &JsValue::from_serde(&metadata)?)?; + timeline_metadata_store.put_key_val_owned( + &self.encode_key(KEYS::ROOM_TIMELINE_METADATA, room_id), + &JsValue::from_serde(&metadata)?, + )?; } } @@ -708,7 +791,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::PRESENCE, IdbTransactionMode::Readonly)? .object_store(KEYS::PRESENCE)? - .get(&user_id.encode())? + .get(&self.encode_key(KEYS::PRESENCE, user_id))? .await? .map(|f| self.deserialize_event(f)) .transpose() @@ -723,7 +806,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::ROOM_STATE, IdbTransactionMode::Readonly)? .object_store(KEYS::ROOM_STATE)? - .get(&(room_id, &event_type, state_key).encode())? + .get(&self.encode_key(KEYS::ROOM_STATE, (room_id, event_type, state_key)))? .await? .map(|f| self.deserialize_event(f)) .transpose() @@ -734,7 +817,7 @@ impl IndexeddbStore { room_id: &RoomId, event_type: StateEventType, ) -> Result>> { - let range = (room_id, &event_type).encode_to_range().map_err(StoreError::Codec)?; + let range = self.encode_to_range(KEYS::ROOM_STATE, (room_id, event_type))?; Ok(self .inner .transaction_on_one_with_mode(KEYS::ROOM_STATE, IdbTransactionMode::Readonly)? @@ -754,7 +837,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::PROFILES, IdbTransactionMode::Readonly)? .object_store(KEYS::PROFILES)? - .get(&(room_id, user_id).encode())? + .get(&self.encode_key(KEYS::PROFILES, (room_id, user_id)))? .await? .map(|f| self.deserialize_event(f)) .transpose() @@ -768,31 +851,19 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::MEMBERS, IdbTransactionMode::Readonly)? .object_store(KEYS::MEMBERS)? - .get(&(room_id, state_key).encode())? + .get(&self.encode_key(KEYS::MEMBERS, (room_id, state_key)))? .await? .map(|f| self.deserialize_event(f)) .transpose() } pub async fn get_user_ids_stream(&self, room_id: &RoomId) -> Result>> { - let range = room_id.encode_to_range().map_err(StoreError::Codec)?; - let skip = room_id.as_encoded_string().len() + 1; - Ok(self - .inner - .transaction_on_one_with_mode(KEYS::MEMBERS, IdbTransactionMode::Readonly)? - .object_store(KEYS::MEMBERS)? - .get_all_keys_with_key(&range)? - .await? - .iter() - .filter_map(|key| match key.as_string() { - Some(k) => UserId::parse(&k[skip..]).ok(), - _ => None, - }) - .collect::>()) + Ok([self.get_invited_user_ids(room_id).await?, self.get_joined_user_ids(room_id).await?] + .concat()) } pub async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result>> { - let range = room_id.encode_to_range().map_err(StoreError::Codec)?; + let range = self.encode_to_range(KEYS::INVITED_USER_IDS, room_id)?; let entries = self .inner .transaction_on_one_with_mode(KEYS::INVITED_USER_IDS, IdbTransactionMode::Readonly)? @@ -807,7 +878,7 @@ impl IndexeddbStore { } pub async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result>> { - let range = room_id.encode_to_range().map_err(StoreError::Codec)?; + let range = self.encode_to_range(KEYS::JOINED_USER_IDS, room_id)?; Ok(self .inner .transaction_on_one_with_mode(KEYS::JOINED_USER_IDS, IdbTransactionMode::Readonly)? @@ -855,7 +926,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::DISPLAY_NAMES, IdbTransactionMode::Readonly)? .object_store(KEYS::DISPLAY_NAMES)? - .get(&(room_id, display_name).encode())? + .get(&self.encode_key(KEYS::DISPLAY_NAMES, (room_id, display_name)))? .await? .map(|f| { self.deserialize_event::>>(f) @@ -871,7 +942,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::ACCOUNT_DATA, IdbTransactionMode::Readonly)? .object_store(KEYS::ACCOUNT_DATA)? - .get(&JsValue::from_str(&event_type.to_string()))? + .get(&self.encode_key(KEYS::ACCOUNT_DATA, event_type))? .await? .map(|f| self.deserialize_event(f).map_err::(|e| e)) .transpose() @@ -885,7 +956,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::ROOM_ACCOUNT_DATA, IdbTransactionMode::Readonly)? .object_store(KEYS::ROOM_ACCOUNT_DATA)? - .get(&(room_id.as_str(), &event_type.to_string()).encode())? + .get(&self.encode_key(KEYS::ROOM_ACCOUNT_DATA, (room_id, event_type)))? .await? .map(|f| self.deserialize_event(f).map_err::(|e| e)) .transpose() @@ -900,7 +971,7 @@ impl IndexeddbStore { self.inner .transaction_on_one_with_mode(KEYS::ROOM_USER_RECEIPTS, IdbTransactionMode::Readonly)? .object_store(KEYS::ROOM_USER_RECEIPTS)? - .get(&(room_id.as_str(), receipt_type.as_ref(), user_id.as_str()).encode())? + .get(&self.encode_key(KEYS::ROOM_USER_RECEIPTS, (room_id, receipt_type, user_id)))? .await? .map(|f| self.deserialize_event(f)) .transpose() @@ -912,38 +983,25 @@ impl IndexeddbStore { receipt_type: ReceiptType, event_id: &EventId, ) -> Result, Receipt)>> { - let key = (room_id, &receipt_type, event_id); - let prefix_len = key.as_encoded_string().len() + 1; - let range = key.encode_to_range().map_err(StoreError::Codec)?; + let range = + self.encode_to_range(KEYS::ROOM_EVENT_RECEIPTS, (room_id, &receipt_type, event_id))?; let tx = self.inner.transaction_on_one_with_mode( KEYS::ROOM_EVENT_RECEIPTS, IdbTransactionMode::Readonly, )?; let store = tx.object_store(KEYS::ROOM_EVENT_RECEIPTS)?; - let mut all = Vec::new(); - for k in store.get_all_keys_with_key(&range)?.await?.iter() { - // FIXME: we should probably parallelize this... - let res = store - .get(&k)? - .await? - .ok_or_else(|| StoreError::Codec(format!("no data at {:?}", k)))?; - let u = if let Some(k_str) = k.as_string() { - UserId::parse(&k_str[prefix_len..]) - .map_err(|e| StoreError::Codec(format!("{:?}", e)))? - } else { - return Err(StoreError::Codec(format!("{:?}", k)).into()); - }; - let r = self - .deserialize_event::(res) - .map_err(|e| StoreError::Codec(e.to_string()))?; - all.push((u, r)); - } - Ok(all) + Ok(store + .get_all_with_key(&range)? + .await? + .iter() + .filter_map(|f| self.deserialize_event(f).ok()) + .collect::>()) } async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { - let key = (&request.source.unique_key(), &request.format.unique_key()).encode(); + let key = self + .encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key())); let tx = self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?; @@ -953,7 +1011,8 @@ impl IndexeddbStore { } async fn get_media_content(&self, request: &MediaRequest) -> Result>> { - let key = (&request.source.unique_key(), &request.format.unique_key()).encode(); + let key = self + .encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key())); self.inner .transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readonly)? .object_store(KEYS::MEDIA)? @@ -997,7 +1056,8 @@ impl IndexeddbStore { } async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { - let key = (&request.source.unique_key(), &request.format.unique_key()).encode(); + let key = self + .encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key())); let tx = self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?; @@ -1007,7 +1067,7 @@ impl IndexeddbStore { } async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { - let range = uri.encode_to_range().map_err(StoreError::Codec)?; + let range = self.encode_to_range(KEYS::MEDIA, uri)?; let tx = self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?; let store = tx.object_store(KEYS::MEDIA)?; @@ -1050,14 +1110,13 @@ impl IndexeddbStore { .inner .transaction_on_multi_with_mode(&all_stores, IdbTransactionMode::Readwrite)?; - let room_key = room_id.encode(); for store_name in direct_stores { - tx.object_store(store_name)?.delete(&room_key)?; + tx.object_store(store_name)?.delete(&self.encode_key(store_name, room_id))?; } - let range = room_id.encode_to_range().map_err(StoreError::Codec)?; for store_name in prefixed_stores { let store = tx.object_store(store_name)?; + let range = self.encode_to_range(store_name, room_id)?; for key in store.get_all_keys_with_key(&range)?.await?.iter() { store.delete(&key)?; } @@ -1069,7 +1128,6 @@ impl IndexeddbStore { &self, room_id: &RoomId, ) -> Result>, Option)>> { - let key = room_id.encode(); let tx = self.inner.transaction_on_multi_with_mode( &[KEYS::ROOM_TIMELINE, KEYS::ROOM_TIMELINE_METADATA], IdbTransactionMode::Readonly, @@ -1077,16 +1135,23 @@ impl IndexeddbStore { let timeline = tx.object_store(KEYS::ROOM_TIMELINE)?; let metadata = tx.object_store(KEYS::ROOM_TIMELINE_METADATA)?; - let metadata: Option = - metadata.get(&key)?.await?.map(|v| v.into_serde()).transpose()?; - if metadata.is_none() { - info!("No timeline for {} was previously stored", room_id); - return Ok(None); - } - let end_token = metadata.and_then(|m| m.end); + let tlm: TimelineMetadata = match metadata + .get(&self.encode_key(KEYS::ROOM_TIMELINE_METADATA, room_id))? + .await? + .map(|v| v.into_serde()) + .transpose()? + { + Some(tl) => tl, + _ => { + info!("No timeline for {} was previously stored", room_id); + return Ok(None); + } + }; + + let end_token = tlm.end; #[allow(clippy::needless_collect)] let timeline: Vec> = timeline - .get_all_with_key(&key)? + .get_all_with_key(&self.encode_to_range(KEYS::ROOM_TIMELINE, room_id)?)? .await? .iter() .map(|v| self.deserialize_event(v).map_err(|e| e.into())) @@ -1280,3 +1345,23 @@ mod tests { statestore_integration_tests! { integration } } + +#[cfg(test)] +mod encrypted_tests { + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + use matrix_sdk_base::statestore_integration_tests; + use uuid::Uuid; + + use super::{IndexeddbStore, Result, StoreCipher}; + + async fn get_store() -> Result { + let db_name = + format!("test-state-encrypted-{}", Uuid::new_v4().to_hyphenated().to_string()); + let key = StoreCipher::new()?; + Ok(IndexeddbStore::open_helper(db_name, Some(key)).await?) + } + + statestore_integration_tests! { integration } +} diff --git a/crates/matrix-sdk-sled/Cargo.toml b/crates/matrix-sdk-sled/Cargo.toml index f1c685e84..19582f757 100644 --- a/crates/matrix-sdk-sled/Cargo.toml +++ b/crates/matrix-sdk-sled/Cargo.toml @@ -12,7 +12,7 @@ crypto-store = ["matrix-sdk-crypto"] [dependencies] futures-core = "0.3.15" futures-util = { version = "0.3.15", default-features = false } -matrix-sdk-base = { path = "../matrix-sdk-base", features = ["store_key"], optional = true } +matrix-sdk-base = { path = "../matrix-sdk-base", optional = true } matrix-sdk-common = { path = "../matrix-sdk-common" } matrix-sdk-crypto = { path = "../matrix-sdk-crypto", optional = true } matrix-sdk-store-encryption = { path = "../matrix-sdk-store-encryption" } diff --git a/crates/matrix-sdk-sled/src/cryptostore.rs b/crates/matrix-sdk-sled/src/cryptostore.rs index d16324285..eaa9b781b 100644 --- a/crates/matrix-sdk-sled/src/cryptostore.rs +++ b/crates/matrix-sdk-sled/src/cryptostore.rs @@ -24,10 +24,7 @@ use dashmap::DashSet; use matrix_sdk_common::{ async_trait, locks::Mutex, - ruma::{ - events::{room_key_request::RequestedKeyInfo, secret::request::SecretName}, - DeviceId, RoomId, TransactionId, UserId, - }, + ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, RoomId, TransactionId, UserId}, }; use matrix_sdk_crypto::{ olm::{ @@ -69,12 +66,20 @@ impl EncodeKey for InboundGroupSession { fn encode(&self) -> Vec { (self.room_id(), self.sender_key(), self.session_id()).encode() } + + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + (self.room_id(), self.sender_key(), self.session_id()) + .encode_secure(table_name, store_cipher) + } } impl EncodeKey for OutboundGroupSession { fn encode(&self) -> Vec { self.room_id().encode() } + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + self.room_id().encode_secure(table_name, store_cipher) + } } impl EncodeKey for Session { @@ -85,6 +90,15 @@ impl EncodeKey for Session { [sender_key.as_bytes(), &[ENCODE_SEPARATOR], session_id.as_bytes(), &[ENCODE_SEPARATOR]] .concat() } + + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + let sender_key = + store_cipher.hash_key(table_name, self.sender_key().to_base64().as_bytes()); + let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes()); + + [sender_key.as_slice(), &[ENCODE_SEPARATOR], session_id.as_slice(), &[ENCODE_SEPARATOR]] + .concat() + } } impl EncodeKey for SecretInfo { @@ -94,31 +108,6 @@ impl EncodeKey for SecretInfo { SecretInfo::SecretRequest(s) => s.encode(), } } -} - -impl EncodeKey for RequestedKeyInfo { - fn encode(&self) -> Vec { - (&self.room_id, &self.sender_key, &self.algorithm, &self.session_id).encode() - } -} - -impl EncodeKey for ReadOnlyDevice { - fn encode(&self) -> Vec { - (self.user_id(), self.device_id()).encode() - } -} - -impl EncodeSecureKey for (&UserId, &DeviceId) { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let user_id = store_cipher.hash_key(table_name, self.0.as_bytes()); - let device_id = store_cipher.hash_key(table_name, self.1.as_bytes()); - - [user_id.as_slice(), &[ENCODE_SEPARATOR], device_id.as_slice(), &[ENCODE_SEPARATOR]] - .concat() - } -} - -impl EncodeSecureKey for SecretInfo { fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { match self { SecretInfo::KeyRequest(k) => k.encode_secure(table_name, store_cipher), @@ -127,15 +116,10 @@ impl EncodeSecureKey for SecretInfo { } } -impl EncodeSecureKey for SecretName { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let name = store_cipher.hash_key(table_name, self.as_ref().as_bytes()); - - [name.as_slice(), &[ENCODE_SEPARATOR]].concat() +impl EncodeKey for RequestedKeyInfo { + fn encode(&self) -> Vec { + (&self.room_id, &self.sender_key, &self.algorithm, &self.session_id).encode() } -} - -impl EncodeSecureKey for RequestedKeyInfo { fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { let room_id = store_cipher.hash_key(table_name, self.room_id.as_bytes()); let sender_key = store_cipher.hash_key(table_name, self.sender_key.as_bytes()); @@ -156,81 +140,15 @@ impl EncodeSecureKey for RequestedKeyInfo { } } -impl EncodeSecureKey for UserId { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let user_id = store_cipher.hash_key(table_name, self.as_bytes()); - - [user_id.as_slice(), &[0xff]].concat() +impl EncodeKey for ReadOnlyDevice { + fn encode(&self) -> Vec { + (self.user_id(), self.device_id()).encode() } -} - -impl EncodeSecureKey for ReadOnlyDevice { fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { (self.user_id(), self.device_id()).encode_secure(table_name, store_cipher) } } -impl EncodeSecureKey for str { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let key = store_cipher.hash_key(table_name, self.as_bytes()); - [key.as_slice(), &[ENCODE_SEPARATOR]].concat() - } -} - -impl EncodeSecureKey for OutboundGroupSession { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - self.room_id().encode_secure(table_name, store_cipher) - } -} - -impl EncodeSecureKey for RoomId { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let room_id = store_cipher.hash_key(table_name, self.as_bytes()); - - [room_id.as_slice(), &[ENCODE_SEPARATOR]].concat() - } -} - -impl EncodeSecureKey for InboundGroupSession { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - (self.room_id(), self.sender_key(), self.session_id()) - .encode_secure(table_name, store_cipher) - } -} - -impl EncodeSecureKey for (&RoomId, &str, &str) { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let first = store_cipher.hash_key(table_name, self.0.as_bytes()); - let second = store_cipher.hash_key(table_name, self.1.as_bytes()); - let third = store_cipher.hash_key(table_name, self.2.as_bytes()); - - [ - first.as_slice(), - &[ENCODE_SEPARATOR], - second.as_slice(), - &[ENCODE_SEPARATOR], - third.as_slice(), - &[ENCODE_SEPARATOR], - ] - .concat() - } -} - -impl EncodeSecureKey for Session { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { - let sender_key = - store_cipher.hash_key(table_name, self.sender_key().to_base64().as_bytes()); - let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes()); - - [sender_key.as_slice(), &[ENCODE_SEPARATOR], session_id.as_slice(), &[ENCODE_SEPARATOR]] - .concat() - } -} - -trait EncodeSecureKey { - fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec; -} - #[derive(Clone, Debug)] pub struct AccountInfo { user_id: Arc, @@ -332,11 +250,7 @@ impl SledStore { } } - fn encode_key( - &self, - table_name: &str, - key: &T, - ) -> Vec { + fn encode_key(&self, table_name: &str, key: T) -> Vec { if let Some(store_cipher) = &*self.store_cipher { key.encode_secure(table_name, store_cipher).to_vec() } else { diff --git a/crates/matrix-sdk-sled/src/encode_key.rs b/crates/matrix-sdk-sled/src/encode_key.rs index a57578a90..7bdcfdb21 100644 --- a/crates/matrix-sdk-sled/src/encode_key.rs +++ b/crates/matrix-sdk-sled/src/encode_key.rs @@ -1,91 +1,163 @@ +use std::{borrow::Cow, ops::Deref}; + use matrix_sdk_common::ruma::{ - events::{secret::request::SecretName, GlobalAccountDataEventType, StateEventType}, - DeviceId, EventId, MxcUri, RoomId, TransactionId, UserId, + events::{ + secret::request::SecretName, GlobalAccountDataEventType, RoomAccountDataEventType, + StateEventType, + }, + receipt::ReceiptType, + DeviceId, EventEncryptionAlgorithm, EventId, MxcUri, RoomId, TransactionId, UserId, }; +use matrix_sdk_store_encryption::StoreCipher; + pub const ENCODE_SEPARATOR: u8 = 0xff; pub trait EncodeKey { - fn encode(&self) -> Vec; + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + unimplemented!() + } + + fn encode(&self) -> Vec { + [self.encode_as_bytes().deref(), &[ENCODE_SEPARATOR]].concat() + } + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + let key = store_cipher.hash_key(table_name, &self.encode_as_bytes()); + [key.as_slice(), &[ENCODE_SEPARATOR]].concat() + } } impl EncodeKey for &T { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + T::encode_as_bytes(self) + } fn encode(&self) -> Vec { T::encode(self) } + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + T::encode_secure(self, table_name, store_cipher) + } } impl EncodeKey for Box { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + T::encode_as_bytes(self) + } fn encode(&self) -> Vec { T::encode(self) } + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + T::encode_secure(self, table_name, store_cipher) + } } impl EncodeKey for str { - fn encode(&self) -> Vec { - [self.as_bytes(), &[ENCODE_SEPARATOR]].concat() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_bytes().into() } } impl EncodeKey for String { - fn encode(&self) -> Vec { - self.as_str().encode() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_str().as_bytes().into() } } impl EncodeKey for DeviceId { - fn encode(&self) -> Vec { - self.as_str().encode() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_str().as_bytes().into() } } impl EncodeKey for EventId { - fn encode(&self) -> Vec { - self.as_str().encode() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_str().as_bytes().into() } } impl EncodeKey for RoomId { - fn encode(&self) -> Vec { - self.as_str().encode() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_str().as_bytes().into() } } impl EncodeKey for TransactionId { - fn encode(&self) -> Vec { - self.as_str().encode() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_str().as_bytes().into() } } impl EncodeKey for MxcUri { - fn encode(&self) -> Vec { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { let s: &str = self.as_ref(); - s.encode() + s.as_bytes().into() } } impl EncodeKey for SecretName { - fn encode(&self) -> Vec { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { let s: &str = self.as_ref(); - s.encode() + s.as_bytes().into() + } +} + +impl EncodeKey for ReceiptType { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + let s: &str = self.as_ref(); + s.as_bytes().into() + } +} + +impl EncodeKey for EventEncryptionAlgorithm { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + let s: &str = self.as_ref(); + s.as_bytes().into() + } +} + +impl EncodeKey for RoomAccountDataEventType { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.to_string().as_bytes().to_vec().into() } } impl EncodeKey for UserId { - fn encode(&self) -> Vec { - self.as_str().encode() + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.as_str().as_bytes().into() + } +} + +impl EncodeKey for StateEventType { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.to_string().as_bytes().to_vec().into() + } +} + +impl EncodeKey for GlobalAccountDataEventType { + fn encode_as_bytes(&self) -> Cow<'_, [u8]> { + self.to_string().as_bytes().to_vec().into() } } impl EncodeKey for (A, B) where - A: AsRef, - B: AsRef, + A: EncodeKey, + B: EncodeKey, { fn encode(&self) -> Vec { [ - self.0.as_ref().as_bytes(), + self.0.encode_as_bytes().deref(), &[ENCODE_SEPARATOR], - self.1.as_ref().as_bytes(), + self.1.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + ] + .concat() + } + + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + [ + store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(), &[ENCODE_SEPARATOR], ] .concat() @@ -94,17 +166,29 @@ where impl EncodeKey for (A, B, C) where - A: AsRef, - B: AsRef, - C: AsRef, + A: EncodeKey, + B: EncodeKey, + C: EncodeKey, { fn encode(&self) -> Vec { [ - self.0.as_ref().as_bytes(), + self.0.encode_as_bytes().deref(), &[ENCODE_SEPARATOR], - self.1.as_ref().as_bytes(), + self.1.encode_as_bytes().deref(), &[ENCODE_SEPARATOR], - self.2.as_ref().as_bytes(), + self.2.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + ] + .concat() + } + + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + [ + store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.2.encode_as_bytes()).as_slice(), &[ENCODE_SEPARATOR], ] .concat() @@ -113,41 +197,36 @@ where impl EncodeKey for (A, B, C, D) where - A: AsRef, - B: AsRef, - C: AsRef, - D: AsRef, + A: EncodeKey, + B: EncodeKey, + C: EncodeKey, + D: EncodeKey, { fn encode(&self) -> Vec { [ - self.0.as_ref().as_bytes(), + self.0.encode_as_bytes().deref(), &[ENCODE_SEPARATOR], - self.1.as_ref().as_bytes(), + self.1.encode_as_bytes().deref(), &[ENCODE_SEPARATOR], - self.2.as_ref().as_bytes(), + self.2.encode_as_bytes().deref(), &[ENCODE_SEPARATOR], - self.3.as_ref().as_bytes(), + self.3.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + ] + .concat() + } + + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + [ + store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.2.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.3.encode_as_bytes()).as_slice(), &[ENCODE_SEPARATOR], ] .concat() } } - -impl EncodeKey for StateEventType { - fn encode(&self) -> Vec { - self.to_string().encode() - } -} - -impl EncodeKey for GlobalAccountDataEventType { - fn encode(&self) -> Vec { - self.to_string().encode() - } -} - -#[cfg(feature = "state-store")] -pub fn encode_key_with_usize>(s: A, i: usize) -> Vec { - // FIXME: Not portable across architectures - [s.as_ref().as_bytes(), &[ENCODE_SEPARATOR], i.to_be_bytes().as_ref(), &[ENCODE_SEPARATOR]] - .concat() -} diff --git a/crates/matrix-sdk-sled/src/state_store.rs b/crates/matrix-sdk-sled/src/state_store.rs index 70f4c369d..2fbe51c00 100644 --- a/crates/matrix-sdk-sled/src/state_store.rs +++ b/crates/matrix-sdk-sled/src/state_store.rs @@ -14,7 +14,6 @@ use std::{ collections::BTreeSet, - convert::TryInto, path::{Path, PathBuf}, sync::Arc, time::Instant, @@ -23,15 +22,12 @@ use std::{ use anyhow::anyhow; use async_stream::stream; use futures_core::stream::Stream; -use futures_util::stream::{self, TryStreamExt}; +use futures_util::stream::{self, StreamExt, TryStreamExt}; use matrix_sdk_base::{ deserialized_responses::{MemberEvent, SyncRoomEvent}, media::{MediaRequest, UniqueKey}, ruma::events::room::redaction::SyncRoomRedactionEvent, - store::{ - store_key::{self, EncryptedEvent, StoreKey}, - BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError, - }, + store::{BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError}, RoomInfo, }; use matrix_sdk_common::{ @@ -52,6 +48,7 @@ use matrix_sdk_common::{ EventId, MxcUri, RoomId, RoomVersionId, UserId, }, }; +use matrix_sdk_store_encryption::{Error as KeyEncryptionError, StoreCipher}; use serde::{Deserialize, Serialize}; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, @@ -62,22 +59,16 @@ use tracing::{info, warn}; #[cfg(feature = "crypto-store")] use super::OpenStoreError; -use crate::encode_key::{encode_key_with_usize, EncodeKey, ENCODE_SEPARATOR}; +use crate::encode_key::{EncodeKey, ENCODE_SEPARATOR}; #[cfg(feature = "crypto-store")] pub use crate::CryptoStore; -#[derive(Debug, Serialize, Deserialize)] -pub enum DatabaseType { - Unencrypted, - Encrypted(store_key::EncryptedStoreKey), -} - #[derive(Debug, thiserror::Error)] pub enum SledStoreError { #[error(transparent)] Json(#[from] serde_json::Error), #[error(transparent)] - Encryption(#[from] store_key::Error), + Encryption(#[from] KeyEncryptionError), #[error(transparent)] StoreError(#[from] StoreError), #[error(transparent)] @@ -104,9 +95,17 @@ impl Into for SledStoreError { SledStoreError::Json(e) => StoreError::Json(e), SledStoreError::Identifier(e) => StoreError::Identifier(e), SledStoreError::Encryption(e) => match e { - store_key::Error::Random(e) => StoreError::Encryption(e.to_string()), - store_key::Error::Serialization(e) => StoreError::Json(e), - store_key::Error::Encryption(e) => StoreError::Encryption(e), + KeyEncryptionError::Random(e) => StoreError::Encryption(e.to_string()), + KeyEncryptionError::Serialization(e) => StoreError::Json(e), + KeyEncryptionError::Encryption(e) => StoreError::Encryption(e.to_string()), + KeyEncryptionError::Version(found, expected) => StoreError::Encryption(format!( + "Bad Database Encryption Version: expected {} found {}", + expected, found + )), + KeyEncryptionError::Length(found, expected) => StoreError::Encryption(format!( + "The database key an invalid length: expected {} found {}", + expected, found + )), }, SledStoreError::StoreError(e) => e, _ => StoreError::Backend(anyhow!(self)), @@ -114,26 +113,36 @@ impl Into for SledStoreError { } } +const ACCOUNT_DATA: &str = "account-data"; +const CUSTOM: &str = "custom"; +const DISPLAY_NAME: &str = "display-name"; +const INVITED_USER_ID: &str = "invited-user-id"; +const JOINED_USER_ID: &str = "joined-user-id"; +const MEDIA: &str = "media"; +const MEMBER: &str = "member"; +const PRESENCE: &str = "presence"; +const PROFILE: &str = "profile"; +const ROOM_ACCOUNT_DATA: &str = "room-account-data"; +const ROOM_EVENT_ID_POSITION: &str = "room-event-id-to-position"; +const ROOM_EVENT_RECEIPT: &str = "room-event-receipt"; +const ROOM_INFO: &str = "room-info"; +const ROOM_STATE: &str = "room-state"; +const ROOM_USER_RECEIPT: &str = "room-user-receipt"; +const ROOM: &str = "room"; +const SESSION: &str = "session"; +const STRIPPED_ROOM_INFO: &str = "stripped-room-info"; +const STRIPPED_ROOM_MEMBER: &str = "stripped-room-member"; +const STRIPPED_ROOM_STATE: &str = "stripped-room-state"; +const TIMELINE_METADATA: &str = "timeline-metadata"; +const TIMELINE: &str = "timeline"; + type Result = std::result::Result; -/// Get the value at `position` in encoded `key`. -/// -/// The key must have been encoded with the `EncodeKey` trait. `position` -/// corresponds to the position in the tuple before the key was encoded. If it -/// wasn't encoded in a tuple, use `0`. -/// -/// Returns `None` if there is no key at `position`. -pub fn decode_key_value(key: &[u8], position: usize) -> Option { - let values: Vec<&[u8]> = key.split(|v| *v == ENCODE_SEPARATOR).collect(); - - values.get(position).map(|s| String::from_utf8_lossy(s).to_string()) -} - #[derive(Clone)] pub struct SledStore { path: Option, pub(crate) inner: Db, - store_key: Arc>, + store_cipher: Arc>, session: Tree, account_data: Tree, members: Tree, @@ -168,40 +177,44 @@ impl std::fmt::Debug for SledStore { } impl SledStore { - fn open_helper(db: Db, path: Option, store_key: Option) -> Result { - let session = db.open_tree("session")?; - let account_data = db.open_tree("account_data")?; + fn open_helper( + db: Db, + path: Option, + store_cipher: Option, + ) -> Result { + let session = db.open_tree(SESSION)?; + let account_data = db.open_tree(ACCOUNT_DATA)?; - let members = db.open_tree("members")?; - let profiles = db.open_tree("profiles")?; - let display_names = db.open_tree("display_names")?; - let joined_user_ids = db.open_tree("joined_user_ids")?; - let invited_user_ids = db.open_tree("invited_user_ids")?; + let members = db.open_tree(MEMBER)?; + let profiles = db.open_tree(PROFILE)?; + let display_names = db.open_tree(DISPLAY_NAME)?; + let joined_user_ids = db.open_tree(JOINED_USER_ID)?; + let invited_user_ids = db.open_tree(INVITED_USER_ID)?; - let room_state = db.open_tree("room_state")?; - let room_info = db.open_tree("room_infos")?; - let presence = db.open_tree("presence")?; - let room_account_data = db.open_tree("room_account_data")?; + let room_state = db.open_tree(ROOM_STATE)?; + let room_info = db.open_tree(ROOM_INFO)?; + let presence = db.open_tree(PRESENCE)?; + let room_account_data = db.open_tree(ROOM_ACCOUNT_DATA)?; - let stripped_room_infos = db.open_tree("stripped_room_infos")?; - let stripped_members = db.open_tree("stripped_members")?; - let stripped_room_state = db.open_tree("stripped_room_state")?; + let stripped_room_infos = db.open_tree(STRIPPED_ROOM_INFO)?; + let stripped_members = db.open_tree(STRIPPED_ROOM_MEMBER)?; + let stripped_room_state = db.open_tree(STRIPPED_ROOM_STATE)?; - let room_user_receipts = db.open_tree("room_user_receipts")?; - let room_event_receipts = db.open_tree("room_event_receipts")?; + let room_user_receipts = db.open_tree(ROOM_USER_RECEIPT)?; + let room_event_receipts = db.open_tree(ROOM_EVENT_RECEIPT)?; - let media = db.open_tree("media")?; + let media = db.open_tree(MEDIA)?; - let custom = db.open_tree("custom")?; + let custom = db.open_tree(CUSTOM)?; - let room_timeline = db.open_tree("room_timeline")?; - let room_timeline_metadata = db.open_tree("room_timeline_metadata")?; - let room_event_id_to_position = db.open_tree("room_event_id_to_position")?; + let room_timeline = db.open_tree(TIMELINE)?; + let room_timeline_metadata = db.open_tree(TIMELINE_METADATA)?; + let room_event_id_to_position = db.open_tree(ROOM_EVENT_ID_POSITION)?; Ok(Self { path, inner: db, - store_key: store_key.into(), + store_cipher: store_cipher.into(), session, account_data, members, @@ -233,6 +246,20 @@ impl SledStore { SledStore::open_helper(db, None, None).map_err(|e| e.into()) } + // testing only + #[cfg(test)] + fn open_encrypted() -> StoreResult { + let db = + Config::new().temporary(true).open().map_err(|e| StoreError::Backend(anyhow!(e)))?; + + SledStore::open_helper( + db, + None, + Some(StoreCipher::new().expect("can't create store cipher")), + ) + .map_err(|e| e.into()) + } + pub fn open_with_passphrase(path: impl AsRef, passphrase: &str) -> StoreResult { Self::inner_open_with_passphrase(path, passphrase).map_err(|e| e.into()) } @@ -241,27 +268,15 @@ impl SledStore { let path = path.as_ref().join("matrix-sdk-state"); let db = Config::new().temporary(false).path(&path).open()?; - let store_key: Option = db - .get("store_key".encode())? - .map(|k| serde_json::from_slice(&k).map_err(StoreError::Json)) - .transpose()?; - - let store_key = if let Some(key) = store_key { - if let DatabaseType::Encrypted(k) = key { - StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)? - } else { - return Err(SledStoreError::StoreError(StoreError::UnencryptedStore)); - } + let store_cipher = if let Some(inner) = db.get("store_cipher".encode())? { + StoreCipher::import(passphrase, &inner)? } else { - let key = StoreKey::new().map_err::(|e| e.into())?; - let encrypted_key = DatabaseType::Encrypted( - key.export(passphrase).map_err::(|e| e.into())?, - ); - db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?; - key + let cipher = StoreCipher::new()?; + db.insert("store_cipher".encode(), cipher.export(passphrase)?)?; + cipher }; - SledStore::open_helper(db, Some(path), Some(store_key)) + SledStore::open_helper(db, Some(path), Some(store_cipher)) } pub fn open_with_path(path: impl AsRef) -> StoreResult { @@ -287,9 +302,8 @@ impl SledStore { } fn serialize_event(&self, event: &impl Serialize) -> Result, SledStoreError> { - if let Some(key) = &*self.store_key { - let encrypted = key.encrypt(event)?; - Ok(serde_json::to_vec(&encrypted)?) + if let Some(key) = &*self.store_cipher { + Ok(key.encrypt_value(event)?) } else { Ok(serde_json::to_vec(event)?) } @@ -299,24 +313,40 @@ impl SledStore { &self, event: &[u8], ) -> Result { - if let Some(key) = &*self.store_key { - let encrypted: EncryptedEvent = serde_json::from_slice(event)?; - Ok(key.decrypt(encrypted)?) + if let Some(key) = &*self.store_cipher { + Ok(key.decrypt_value(event)?) } else { Ok(serde_json::from_slice(event)?) } } - pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.session.insert(("filter", filter_name).encode(), filter_id)?; + fn encode_key(&self, table_name: &str, key: T) -> Vec { + if let Some(store_cipher) = &*self.store_cipher { + key.encode_secure(table_name, store_cipher).to_vec() + } else { + key.encode() + } + } + fn encode_key_with_counter(&self, tablename: &str, s: A, i: usize) -> Vec { + [ + &self.encode_key(tablename, s), + [ENCODE_SEPARATOR].as_slice(), + i.to_be_bytes().as_ref(), + [ENCODE_SEPARATOR].as_slice(), + ] + .concat() + } + + pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { + self.session.insert(self.encode_key(SESSION, ("filter", filter_name)), filter_id)?; Ok(()) } pub async fn get_filter(&self, filter_name: &str) -> Result> { Ok(self .session - .get(("filter", filter_name).encode())? + .get(self.encode_key(SESSION, ("filter", filter_name)))? .map(|f| String::from_utf8_lossy(&f).to_string())) } @@ -371,25 +401,31 @@ impl SledStore { let profile_changes = changes.profiles.get(room); for event in events.values() { - let key = (room, &event.state_key).encode(); + let key = (room, &event.state_key); match event.content.membership { MembershipState::Join => { - joined.insert(key.as_slice(), event.state_key.as_str())?; - invited.remove(key.as_slice())?; + joined.insert( + self.encode_key(JOINED_USER_ID, &key), + event.state_key.as_str(), + )?; + invited.remove(self.encode_key(INVITED_USER_ID, &key))?; } MembershipState::Invite => { - invited.insert(key.as_slice(), event.state_key.as_str())?; - joined.remove(key.as_slice())?; + invited.insert( + self.encode_key(INVITED_USER_ID, &key), + event.state_key.as_str(), + )?; + joined.remove(self.encode_key(JOINED_USER_ID, &key))?; } _ => { - joined.remove(key.as_slice())?; - invited.remove(key.as_slice())?; + joined.remove(self.encode_key(JOINED_USER_ID, &key))?; + invited.remove(self.encode_key(INVITED_USER_ID, &key))?; } } members.insert( - key.as_slice(), + self.encode_key(MEMBER, &key), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -398,7 +434,7 @@ impl SledStore { profile_changes.and_then(|p| p.get(&event.state_key)) { profiles.insert( - key.as_slice(), + self.encode_key(PROFILE, &key), self.serialize_event(&profile) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -409,7 +445,7 @@ impl SledStore { for (room_id, ambiguity_maps) in &changes.ambiguity_maps { for (display_name, map) in ambiguity_maps { display_names.insert( - (room_id, display_name).encode(), + self.encode_key(DISPLAY_NAME, (room_id, display_name)), self.serialize_event(&map) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -418,7 +454,7 @@ impl SledStore { for (event_type, event) in &changes.account_data { account_data.insert( - event_type.encode(), + self.encode_key(ACCOUNT_DATA, event_type), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -427,7 +463,7 @@ impl SledStore { for (room, events) in &changes.room_account_data { for (event_type, event) in events { room_account_data.insert( - (room, event_type.to_string()).encode(), + self.encode_key(ROOM_ACCOUNT_DATA, &(room, event_type)), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -438,7 +474,7 @@ impl SledStore { for (event_type, events) in event_types { for (state_key, event) in events { state.insert( - (room, event_type.to_string(), state_key).encode(), + self.encode_key(ROOM_STATE, (room, event_type, state_key)), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -448,7 +484,7 @@ impl SledStore { for (room_id, room_info) in &changes.room_infos { rooms.insert( - room_id.encode(), + self.encode_key(ROOM, room_id), self.serialize_event(room_info) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -456,7 +492,7 @@ impl SledStore { for (sender, event) in &changes.presence { presence.insert( - sender.encode(), + self.encode_key(PRESENCE, sender), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -464,7 +500,7 @@ impl SledStore { for (room_id, info) in &changes.stripped_room_infos { striped_rooms.insert( - room_id.encode(), + self.encode_key(STRIPPED_ROOM_INFO, room_id), self.serialize_event(&info) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -473,7 +509,10 @@ impl SledStore { for (room, events) in &changes.stripped_members { for event in events.values() { stripped_members.insert( - (room, event.state_key.to_string()).encode(), + self.encode_key( + STRIPPED_ROOM_MEMBER, + (room, event.state_key.to_string()), + ), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -484,7 +523,10 @@ impl SledStore { for (event_type, events) in event_types { for (state_key, event) in events { stripped_state.insert( - (room, event_type.to_string(), state_key).encode(), + self.encode_key( + STRIPPED_ROOM_STATE, + (room, event_type.to_string(), state_key), + ), self.serialize_event(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -507,7 +549,10 @@ impl SledStore { for (user_id, receipt) in receipts { // Add the receipt to the room user receipts if let Some(old) = room_user_receipts.insert( - (room, receipt_type, user_id).encode(), + self.encode_key( + ROOM_USER_RECEIPT, + (room, receipt_type, user_id), + ), self.serialize_event(&(event_id, receipt)) .map_err(ConflictableTransactionError::Abort)?, )? { @@ -515,15 +560,19 @@ impl SledStore { let (old_event, _): (Box, Receipt) = self .deserialize_event(&old) .map_err(ConflictableTransactionError::Abort)?; - room_event_receipts.remove( - (room, receipt_type, old_event, user_id).encode(), - )?; + room_event_receipts.remove(self.encode_key( + ROOM_EVENT_RECEIPT, + (room, receipt_type, old_event, user_id), + ))?; } // Add the receipt to the room event receipts room_event_receipts.insert( - (room, receipt_type, event_id, user_id).encode(), - self.serialize_event(receipt) + self.encode_key( + ROOM_EVENT_RECEIPT, + (room, receipt_type, event_id, user_id), + ), + self.serialize_event(&(user_id, receipt)) .map_err(ConflictableTransactionError::Abort)?, )?; } @@ -548,7 +597,7 @@ impl SledStore { pub async fn get_presence_event(&self, user_id: &UserId) -> Result>> { let db = self.clone(); - let key = user_id.encode(); + let key = self.encode_key(PRESENCE, user_id); spawn_blocking(move || db.presence.get(key)?.map(|e| db.deserialize_event(&e)).transpose()) .await? } @@ -560,7 +609,7 @@ impl SledStore { state_key: &str, ) -> Result>> { let db = self.clone(); - let key = (room_id, event_type.to_string(), state_key).encode(); + let key = self.encode_key(ROOM_STATE, (room_id, event_type.to_string(), state_key)); spawn_blocking(move || { db.room_state.get(key)?.map(|e| db.deserialize_event(&e)).transpose() }) @@ -573,7 +622,7 @@ impl SledStore { event_type: StateEventType, ) -> Result>> { let db = self.clone(); - let key = (room_id, event_type.to_string()).encode(); + let key = self.encode_key(ROOM_STATE, (room_id, event_type.to_string())); spawn_blocking(move || { db.room_state .scan_prefix(key) @@ -589,7 +638,7 @@ impl SledStore { user_id: &UserId, ) -> Result> { let db = self.clone(); - let key = (room_id, user_id).encode(); + let key = self.encode_key(PROFILE, (room_id, user_id)); spawn_blocking(move || db.profiles.get(key)?.map(|p| db.deserialize_event(&p)).transpose()) .await? } @@ -600,7 +649,7 @@ impl SledStore { state_key: &UserId, ) -> Result> { let db = self.clone(); - let key = (room_id, state_key).encode(); + let key = self.encode_key(MEMBER, (room_id, state_key)); spawn_blocking(move || db.members.get(key)?.map(|v| db.deserialize_event(&v)).transpose()) .await? } @@ -609,29 +658,10 @@ impl SledStore { &self, room_id: &RoomId, ) -> StoreResult>>> { - let decode = |key: &[u8]| -> StoreResult> { - let mut iter = key.split(|c| c == &ENCODE_SEPARATOR); - // Our key is a the room id separated from the user id by a null - // byte, discard the first value of the split. - iter.next(); - - let user_id = iter.next().expect("User ids weren't properly encoded"); - - Ok(UserId::parse(String::from_utf8_lossy(user_id).to_string())?) - }; - - let members = self.members.clone(); - let key = room_id.encode(); - - spawn_blocking(move || { - stream::iter( - members - .scan_prefix(key) - .map(move |u| decode(&u.map_err(|e| StoreError::Backend(anyhow!(e)))?.0)), - ) - }) - .await - .map_err(|e| StoreError::Backend(anyhow!(e))) + Ok(self + .get_joined_user_ids(room_id) + .await? + .chain(self.get_invited_user_ids(room_id).await?)) } pub async fn get_invited_user_ids( @@ -639,7 +669,7 @@ impl SledStore { room_id: &RoomId, ) -> StoreResult>>> { let db = self.clone(); - let key = room_id.encode(); + let key = self.encode_key(INVITED_USER_ID, room_id); spawn_blocking(move || { stream::iter(db.invited_user_ids.scan_prefix(key).map(|u| { UserId::parse( @@ -658,7 +688,7 @@ impl SledStore { room_id: &RoomId, ) -> StoreResult>>> { let db = self.clone(); - let key = room_id.encode(); + let key = self.encode_key(JOINED_USER_ID, room_id); spawn_blocking(move || { stream::iter(db.joined_user_ids.scan_prefix(key).map(|u| { UserId::parse( @@ -696,7 +726,7 @@ impl SledStore { display_name: &str, ) -> Result>> { let db = self.clone(); - let key = (room_id, display_name).encode(); + let key = self.encode_key(DISPLAY_NAME, (room_id, display_name)); spawn_blocking(move || { Ok(db .display_names @@ -713,7 +743,7 @@ impl SledStore { event_type: GlobalAccountDataEventType, ) -> Result>> { let db = self.clone(); - let key = event_type.encode(); + let key = self.encode_key(ACCOUNT_DATA, event_type); spawn_blocking(move || { db.account_data.get(key)?.map(|m| db.deserialize_event(&m)).transpose() }) @@ -726,7 +756,7 @@ impl SledStore { event_type: RoomAccountDataEventType, ) -> Result>> { let db = self.clone(); - let key = (room_id, event_type.to_string()).encode(); + let key = self.encode_key(ROOM_ACCOUNT_DATA, (room_id, event_type)); spawn_blocking(move || { db.room_account_data.get(key)?.map(|m| db.deserialize_event(&m)).transpose() }) @@ -740,7 +770,7 @@ impl SledStore { user_id: &UserId, ) -> Result, Receipt)>> { let db = self.clone(); - let key = (room_id, receipt_type, user_id).encode(); + let key = self.encode_key(ROOM_USER_RECEIPT, (room_id, receipt_type, user_id)); spawn_blocking(move || { db.room_user_receipts.get(key)?.map(|m| db.deserialize_event(&m)).transpose() }) @@ -754,19 +784,14 @@ impl SledStore { event_id: &EventId, ) -> StoreResult, Receipt)>> { let db = self.clone(); - let key = (room_id, receipt_type, event_id).encode(); + let key = self.encode_key(ROOM_EVENT_RECEIPT, (room_id, receipt_type, event_id)); spawn_blocking(move || { db.room_event_receipts .scan_prefix(key) + .values() .map(|u| { - u.map_err(|e| StoreError::Backend(anyhow!(e))).and_then(|(key, value)| { - db.deserialize_event(&value) - // TODO remove this unwrapping - .map(|receipt| { - (decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt) - }) - .map_err(|e| StoreError::Backend(anyhow!(e))) - }) + let v = u.map_err(|e| StoreError::Backend(anyhow!(e)))?; + db.deserialize_event(&v).map_err(|e| StoreError::Backend(anyhow!(e))) }) .collect() }) @@ -775,8 +800,10 @@ impl SledStore { } async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { - self.media - .insert((request.source.unique_key(), request.format.unique_key()).encode(), data)?; + self.media.insert( + self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())), + data, + )?; self.inner.flush_async().await?; @@ -785,7 +812,8 @@ impl SledStore { async fn get_media_content(&self, request: &MediaRequest) -> Result>> { let db = self.clone(); - let key = (request.source.unique_key(), request.format.unique_key()).encode(); + let key = + self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())); spawn_blocking(move || Ok(db.media.get(key)?.map(|m| m.to_vec()))).await? } @@ -804,13 +832,15 @@ impl SledStore { } async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { - self.media.remove((request.source.unique_key(), request.format.unique_key()).encode())?; + self.media.remove( + self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())), + )?; Ok(()) } async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { - let keys = self.media.scan_prefix(uri.encode()).keys(); + let keys = self.media.scan_prefix(self.encode_key(MEDIA, uri)).keys(); let mut batch = sled::Batch::default(); for key in keys { @@ -821,60 +851,75 @@ impl SledStore { } async fn remove_room(&self, room_id: &RoomId) -> Result<()> { - let room_key = room_id.encode(); - let mut members_batch = sled::Batch::default(); - for key in self.members.scan_prefix(room_key.as_slice()).keys() { + for key in self.members.scan_prefix(self.encode_key(MEMBER, room_id)).keys() { members_batch.remove(key?) } let mut profiles_batch = sled::Batch::default(); - for key in self.profiles.scan_prefix(room_key.as_slice()).keys() { + for key in self.profiles.scan_prefix(self.encode_key(PROFILE, room_id)).keys() { profiles_batch.remove(key?) } let mut display_names_batch = sled::Batch::default(); - for key in self.display_names.scan_prefix(room_key.as_slice()).keys() { + for key in self.display_names.scan_prefix(self.encode_key(DISPLAY_NAME, room_id)).keys() { display_names_batch.remove(key?) } let mut joined_user_ids_batch = sled::Batch::default(); - for key in self.joined_user_ids.scan_prefix(room_key.as_slice()).keys() { + for key in self.joined_user_ids.scan_prefix(self.encode_key(JOINED_USER_ID, room_id)).keys() + { joined_user_ids_batch.remove(key?) } let mut invited_user_ids_batch = sled::Batch::default(); - for key in self.invited_user_ids.scan_prefix(room_key.as_slice()).keys() { + for key in + self.invited_user_ids.scan_prefix(self.encode_key(INVITED_USER_ID, room_id)).keys() + { invited_user_ids_batch.remove(key?) } let mut room_state_batch = sled::Batch::default(); - for key in self.room_state.scan_prefix(room_key.as_slice()).keys() { + for key in self.room_state.scan_prefix(self.encode_key(ROOM_STATE, room_id)).keys() { room_state_batch.remove(key?) } let mut room_account_data_batch = sled::Batch::default(); - for key in self.room_account_data.scan_prefix(room_key.as_slice()).keys() { + for key in + self.room_account_data.scan_prefix(self.encode_key(ROOM_ACCOUNT_DATA, room_id)).keys() + { room_account_data_batch.remove(key?) } let mut stripped_members_batch = sled::Batch::default(); - for key in self.stripped_members.scan_prefix(room_key.as_slice()).keys() { + for key in + self.stripped_members.scan_prefix(self.encode_key(STRIPPED_ROOM_MEMBER, room_id)).keys() + { stripped_members_batch.remove(key?) } let mut stripped_room_state_batch = sled::Batch::default(); - for key in self.stripped_room_state.scan_prefix(room_key.as_slice()).keys() { + for key in self + .stripped_room_state + .scan_prefix(self.encode_key(STRIPPED_ROOM_STATE, room_id)) + .keys() + { stripped_room_state_batch.remove(key?) } let mut room_user_receipts_batch = sled::Batch::default(); - for key in self.room_user_receipts.scan_prefix(room_key.as_slice()).keys() { + for key in + self.room_user_receipts.scan_prefix(self.encode_key(ROOM_USER_RECEIPT, room_id)).keys() + { room_user_receipts_batch.remove(key?) } let mut room_event_receipts_batch = sled::Batch::default(); - for key in self.room_event_receipts.scan_prefix(room_key.as_slice()).keys() { + for key in self + .room_event_receipts + .scan_prefix(self.encode_key(ROOM_EVENT_RECEIPT, room_id)) + .keys() + { room_event_receipts_batch.remove(key?) } @@ -909,8 +954,8 @@ impl SledStore { room_user_receipts, room_event_receipts, )| { - rooms.remove(room_key.as_slice())?; - stripped_rooms.remove(room_key.as_slice())?; + rooms.remove(self.encode_key(ROOM, room_id))?; + stripped_rooms.remove(self.encode_key(STRIPPED_ROOM_INFO, room_id))?; members.apply_batch(&members_batch)?; profiles.apply_batch(&profiles_batch)?; @@ -942,7 +987,7 @@ impl SledStore { room_id: &RoomId, ) -> Result>, Option)>> { let db = self.clone(); - let key = room_id.encode(); + let key = self.encode_key(TIMELINE_METADATA, room_id); let r_id = room_id.to_owned(); let metadata: Option = db .room_timeline_metadata @@ -963,7 +1008,7 @@ impl SledStore { info!("Found previously stored timeline for {}, with end token {:?}", r_id, end_token); let stream = stream! { - while let Ok(Some(item)) = db.room_timeline.get(&encode_key_with_usize(&r_id, position)) { + while let Ok(Some(item)) = db.room_timeline.get(&db.encode_key_with_counter(TIMELINE, &r_id, position)) { position += 1; yield db.deserialize_event(&item).map_err(SledStoreError::from).map_err(|e| e.into()); } @@ -973,16 +1018,19 @@ impl SledStore { } async fn remove_room_timeline(&self, room_id: &RoomId) -> Result<()> { - let room_key = room_id.encode(); info!("Remove stored timeline for {}", room_id); let mut timeline_batch = sled::Batch::default(); - for key in self.room_timeline.scan_prefix(room_key.as_slice()).keys() { + for key in self.room_timeline.scan_prefix(self.encode_key(TIMELINE, &room_id)).keys() { timeline_batch.remove(key?) } let mut event_id_to_position_batch = sled::Batch::default(); - for key in self.room_event_id_to_position.scan_prefix(room_key.as_slice()).keys() { + for key in self + .room_event_id_to_position + .scan_prefix(self.encode_key(ROOM_EVENT_ID_POSITION, &room_id)) + .keys() + { event_id_to_position_batch.remove(key?) } @@ -990,7 +1038,8 @@ impl SledStore { (&self.room_timeline, &self.room_timeline_metadata, &self.room_event_id_to_position) .transaction( |(room_timeline, room_timeline_metadata, room_event_id_to_position)| { - room_timeline_metadata.remove(room_key.as_slice())?; + room_timeline_metadata + .remove(self.encode_key(TIMELINE_METADATA, &room_id))?; room_timeline.apply_batch(&timeline_batch)?; room_event_id_to_position.apply_batch(&event_id_to_position_batch)?; @@ -1015,7 +1064,6 @@ impl SledStore { } else { info!("Save new timeline batch from messages response for {}", room_id); } - let room_key = room_id.encode(); let metadata: Option = if timeline.limited { info!( @@ -1027,7 +1075,7 @@ impl SledStore { } else { let metadata: Option = self .room_timeline_metadata - .get(room_key.as_slice())? + .get(self.encode_key(TIMELINE_METADATA, &room_id))? .map(|v| serde_json::from_slice(&v).map_err(StoreError::Json)) .transpose()?; if let Some(mut metadata) = metadata { @@ -1043,7 +1091,8 @@ impl SledStore { let mut delete_timeline = false; for event in &timeline.events { if let Some(event_id) = event.event_id() { - let event_key = (room_id, event_id).encode(); + let event_key = + self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, event_id)); if self.room_event_id_to_position.contains_key(event_key)? { delete_timeline = true; break; @@ -1082,7 +1131,7 @@ impl SledStore { }; let room_version = self .room_info - .get(room_id.encode())? + .get(&self.encode_key(ROOM_INFO, room_id))? .map(|r| self.deserialize_event::(&r)) .transpose()? .and_then(|info| info.room_version().cloned()) @@ -1100,7 +1149,8 @@ impl SledStore { )), )) = event.event.deserialize() { - let redacts_key = (room_id, redaction.redacts).encode(); + let redacts_key = + self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, redaction.redacts)); if let Some(position_key) = self.room_event_id_to_position.get(redacts_key)? { @@ -1125,28 +1175,35 @@ impl SledStore { } metadata.start_position -= 1; - let key = encode_key_with_usize(room_id, metadata.start_position); + let key = + self.encode_key_with_counter(TIMELINE, room_id, metadata.start_position); timeline_batch.insert(key.as_slice(), self.serialize_event(&event)?); // Only add event with id to the position map if let Some(event_id) = event.event_id() { - let event_key = (room_id, event_id).encode(); + let event_key = + self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, event_id)); event_id_to_position_batch.insert(event_key.as_slice(), key.as_slice()); } } } else { for event in &timeline.events { metadata.end_position += 1; - let key = encode_key_with_usize(room_id, metadata.end_position); + let key = + self.encode_key_with_counter(TIMELINE, room_id, metadata.end_position); timeline_batch.insert(key.as_slice(), self.serialize_event(&event)?); // Only add event with id to the position map if let Some(event_id) = event.event_id() { - let event_key = (room_id, event_id).encode(); + let event_key = + self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, event_id)); event_id_to_position_batch.insert(event_key.as_slice(), key.as_slice()); } } } - timeline_metadata_batch.insert(room_key, serde_json::to_vec(&metadata)?); + timeline_metadata_batch.insert( + self.encode_key(TIMELINE_METADATA, &room_id), + serde_json::to_vec(&metadata)?, + ); } let ret: Result<(), TransactionError> = @@ -1355,3 +1412,16 @@ mod tests { statestore_integration_tests! { integration } } + +#[cfg(test)] +mod encrypted_tests { + use matrix_sdk_base::statestore_integration_tests; + + use super::{SledStore, StateStore, StoreResult}; + + async fn get_store() -> StoreResult { + SledStore::open_encrypted().map_err(Into::into) + } + + statestore_integration_tests! { integration } +} diff --git a/crates/matrix-sdk-store-encryption/src/lib.rs b/crates/matrix-sdk-store-encryption/src/lib.rs index d3e54282a..0e0ad8d17 100644 --- a/crates/matrix-sdk-store-encryption/src/lib.rs +++ b/crates/matrix-sdk-store-encryption/src/lib.rs @@ -284,16 +284,83 @@ impl StoreCipher { /// # Result::<_, anyhow::Error>::Ok(()) }; /// ``` pub fn encrypt_value(&self, value: &impl Serialize) -> Result, Error> { - let mut data = serde_json::to_vec(value)?; + Ok(serde_json::to_vec(&self.encrypt_value_typed(value)?)?) + } + /// Encrypt a value before it is inserted into the key/value store. + /// + /// A value can be decrypted using the + /// [`StoreCipher::decrypt_value_typed()`] method. This is the lower + /// level function to `encrypt_value`, but returns the + /// full `EncryptdValue`-type + /// + /// # Arguments + /// + /// * `value` - A value that should be encrypted, any value that implements + /// `Serialize` can be given to this method. The value will be serialized as + /// json before it is encrypted. + /// + /// + /// # Examples + /// + /// ```no_run + /// # let example = || { + /// use matrix_sdk_store_encryption::StoreCipher; + /// use serde_json::{json, value::Value}; + /// + /// let store_cipher = StoreCipher::new()?; + /// + /// let value = json!({ + /// "some": "data", + /// }); + /// + /// let encrypted = store_cipher.encrypt_value_typed(&value)?; + /// let decrypted: Value = store_cipher.decrypt_value_typed(encrypted)?; + /// + /// assert_eq!(value, decrypted); + /// # anyhow::Ok(()) }; + /// ``` + pub fn encrypt_value_typed(&self, value: &impl Serialize) -> Result { + let data = serde_json::to_vec(value)?; + self.encrypt_value_data(data) + } + + /// Encrypt some data before it is inserted into the key/value store. + /// + /// A value can be decrypted using the [`StoreCipher::decrypt_value_data()`] + /// method. This is the lower level function to `encrypt_value` + /// + /// # Arguments + /// + /// * `data` - A value that should be encrypted, encoded as a `Vec` + /// + /// # Examples + /// + /// ``` + /// # let example = || { + /// use matrix_sdk_store_encryption::StoreCipher; + /// use serde_json::{json, value::Value}; + /// + /// let store_cipher = StoreCipher::new()?; + /// + /// let value = serde_json::to_vec(&json!({ + /// "some": "data", + /// }))?; + /// + /// let encrypted = store_cipher.encrypt_value_data(value.clone())?; + /// let decrypted = store_cipher.decrypt_value_data(encrypted)?; + /// + /// assert_eq!(value, decrypted); + /// # Result::<_, anyhow::Error>::Ok(()) }; + /// ``` + pub fn encrypt_value_data(&self, mut data: Vec) -> Result { let nonce = Keys::get_nonce()?; let cipher = XChaCha20Poly1305::new(self.inner.encryption_key()); let ciphertext = cipher.encrypt(XNonce::from_slice(&nonce), data.as_ref())?; data.zeroize(); - - Ok(serde_json::to_vec(&EncryptedValue { version: VERSION, ciphertext, nonce })?) + Ok(EncryptedValue { version: VERSION, ciphertext, nonce }) } /// Decrypt a value after it was fetchetd from the key/value store. @@ -325,22 +392,91 @@ impl StoreCipher { /// /// assert_eq!(value, decrypted); /// # Result::<_, anyhow::Error>::Ok(()) }; + /// ``` pub fn decrypt_value Deserialize<'b>>(&self, value: &[u8]) -> Result { let value: EncryptedValue = serde_json::from_slice(value)?; + self.decrypt_value_typed(value) + } + /// Decrypt a value after it was fetchetd from the key/value store. + /// + /// A value can be encrypted using the + /// [`StoreCipher::encrypt_value_typed()`] method. Lower level method to + /// [`StoreCipher::decrypt_value_typed()`] + /// + /// # Arguments + /// + /// * `value` - The EncryptedValue of a value that should be decrypted. + /// + /// The method will deserialize the decrypted value into the expected type. + /// + /// # Examples + /// + /// ``` + /// # let example = || { + /// use matrix_sdk_store_encryption::StoreCipher; + /// use serde_json::{json, value::Value}; + /// + /// let store_cipher = StoreCipher::new()?; + /// + /// let value = json!({ + /// "some": "data", + /// }); + /// + /// let encrypted = store_cipher.encrypt_value_typed(&value)?; + /// let decrypted: Value = store_cipher.decrypt_value_typed(encrypted)?; + /// + /// assert_eq!(value, decrypted); + /// # Result::<_, anyhow::Error>::Ok(()) }; + /// ``` + pub fn decrypt_value_typed Deserialize<'b>>( + &self, + value: EncryptedValue, + ) -> Result { + let mut plaintext = self.decrypt_value_data(value)?; + let ret = serde_json::from_slice(&plaintext); + plaintext.zeroize(); + Ok(ret?) + } + + /// Decrypt a value after it was fetchetd from the key/value store. + /// + /// A value can be encrypted using the [`StoreCipher::encrypt_value_data()`] + /// method. Lower level method to [`StoreCipher::decrypt_value()`]. + /// + /// # Arguments + /// + /// * `value` - The EncryptedValue of a value that should be decrypted. + /// + /// The method will return the raw decrypted value + /// + /// # Examples + /// + /// ``` + /// # let example = || { + /// use matrix_sdk_store_encryption::StoreCipher; + /// use serde_json::{json, value::Value}; + /// + /// let store_cipher = StoreCipher::new()?; + /// + /// let value = serde_json::to_vec(&json!({ + /// "some": "data", + /// }))?; + /// + /// let encrypted = store_cipher.encrypt_value_data(value.clone())?; + /// let decrypted = store_cipher.decrypt_value_data(encrypted)?; + /// + /// assert_eq!(value, decrypted); + /// # Result::<_, anyhow::Error>::Ok(()) }; + /// ``` + pub fn decrypt_value_data(&self, value: EncryptedValue) -> Result, Error> { if value.version != VERSION { return Err(Error::Version(VERSION, value.version)); } let cipher = XChaCha20Poly1305::new(self.inner.encryption_key()); let nonce = XNonce::from_slice(&value.nonce); - let mut plaintext = cipher.decrypt(nonce, value.ciphertext.as_ref())?; - - let ret = serde_json::from_slice(&plaintext); - - plaintext.zeroize(); - - Ok(ret?) + Ok(cipher.decrypt(nonce, value.ciphertext.as_ref())?) } /// Expand the given passphrase into a KEY_SIZE long key. @@ -362,8 +498,10 @@ impl MacKey { } } +/// Encrypted value, ready for storage, as created by the +/// [`StoreCipher::encrypt_value_data()`] #[derive(Debug, Serialize, Deserialize, PartialEq)] -struct EncryptedValue { +pub struct EncryptedValue { version: u8, ciphertext: Vec, nonce: [u8; XNONCE_SIZE],