feat(state_store): Store keys and values encrypted, too

Merge pull request #586 from matrix-org/ben-remove-store-key

This removes the store_key and replaces it with the new key cipher for both sled and indexeddb state stores. If they are given a passphrase, we are now encrypted/hashing all keys and values stored in the store, no leaking of metadata in any key or value anymore.
This commit is contained in:
Benjamin Kampmann
2022-04-19 16:46:11 +02:00
committed by GitHub
13 changed files with 997 additions and 894 deletions

View File

@@ -17,17 +17,9 @@ rustdoc-args = ["--cfg", "docsrs"]
[features]
default = []
encryption = ["matrix-sdk-crypto", "store_key"]
encryption = ["matrix-sdk-crypto"]
qrcode = ["matrix-sdk-crypto/qrcode"]
# Store Key is a helper feature for state store implementations
store_key = [
"pbkdf2",
"hmac",
"sha2",
"rand",
"chacha20poly1305",
]
# helpers for testing features build upon this
testing = [ "http" ]

View File

@@ -7,5 +7,4 @@ The following crate feature flags are available:
* `encryption`: Enables end-to-end encryption support in the library.
* `qrcode`: Enbles QRcode generation and reading code
* `store_key`: Provides extra facilities for StateStores to manage store keys
* `testing`: provides facilities and functions for tests, in particular for integration testing store implementations. ATTENTION: do not ever use outside of tests, we do not provide any stability warantees on these, these are merely helpers. If you find you _need_ any function provided here outside of tests, please open a Github Issue and inform us about your use case for us to consider.

View File

@@ -277,18 +277,18 @@ macro_rules! statestore_integration_tests {
assert!(store.get_sync_token().await?.is_some());
assert!(store.get_presence_event(user_id).await?.is_some());
assert_eq!(store.get_room_infos().await?.len(), 1);
assert_eq!(store.get_stripped_room_infos().await?.len(), 1);
assert_eq!(store.get_room_infos().await?.len(), 1, "Expected to find 1 room info ");
assert_eq!(store.get_stripped_room_infos().await?.len(), 1, "Expected to find 1 stripped room info");
assert!(store.get_account_data_event(GlobalAccountDataEventType::PushRules).await?.is_some());
assert!(store.get_state_event(room_id, StateEventType::RoomName, "").await?.is_some());
assert_eq!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), 1);
assert_eq!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), 1, "Expected to find 1 room topic");
assert!(store.get_profile(room_id, user_id).await?.is_some());
assert!(store.get_member_event(room_id, user_id).await?.is_some());
assert_eq!(store.get_user_ids(room_id).await?.len(), 2);
assert_eq!(store.get_invited_user_ids(room_id).await?.len(), 1);
assert_eq!(store.get_joined_user_ids(room_id).await?.len(), 1);
assert_eq!(store.get_users_with_display_name(room_id, "example").await?.len(), 2);
assert_eq!(store.get_user_ids(room_id).await?.len(), 2, "Expected to find 2 members for room");
assert_eq!(store.get_invited_user_ids(room_id).await?.len(), 1, "Expected to find 1 invited user ids");
assert_eq!(store.get_joined_user_ids(room_id).await?.len(), 1, "Expected to find 1 joined user ids");
assert_eq!(store.get_users_with_display_name(room_id, "example").await?.len(), 2, "Expected to find 2 display names for room");
assert!(store
.get_room_account_data_event(room_id, RoomAccountDataEventType::Tag)
.await?
@@ -302,8 +302,7 @@ macro_rules! statestore_integration_tests {
.get_event_room_receipt_events(room_id, ReceiptType::Read, first_receipt_event_id())
.await?
.len(),
1
);
1, "Expected to find 1 read receipt");
Ok(())
}
@@ -325,7 +324,7 @@ macro_rules! statestore_integration_tests {
assert!(store.get_member_event(room_id, user_id).await.unwrap().is_some());
let members = store.get_user_ids(room_id).await.unwrap();
assert!(!members.is_empty())
assert!(!members.is_empty(), "We expected to find members for the room")
}
#[async_test]
@@ -354,7 +353,7 @@ macro_rules! statestore_integration_tests {
#[async_test]
async fn test_receipts_saving() {
let store = get_store().await.unwrap();
let store = get_store().await.expect("creating store failed");
let room_id = room_id!("!test_receipts_saving:localhost");
@@ -370,7 +369,7 @@ macro_rules! statestore_integration_tests {
}
}
}))
.unwrap();
.expect("json creation failed");
let second_receipt_event = serde_json::from_value(json!({
second_event_id.clone(): {
@@ -381,68 +380,70 @@ macro_rules! statestore_integration_tests {
}
}
}))
.unwrap();
.expect("json creation failed");
assert!(store
.get_user_room_receipt_event(room_id, ReceiptType::Read, user_id())
.await
.unwrap()
.expect("failed to read user room receipt")
.is_none());
assert!(store
.get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.expect("failed to read user room receipt for 1")
.is_empty());
assert!(store
.get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.expect("failed to read user room receipt for 2")
.is_empty());
let mut changes = StateChanges::default();
changes.add_receipts(room_id, first_receipt_event);
store.save_changes(&changes).await.unwrap();
store.save_changes(&changes).await.expect("writing changes fauked");
assert!(store
.get_user_room_receipt_event(room_id, ReceiptType::Read, user_id())
.await
.unwrap()
.is_some(),);
.expect("failed to read user room receipt after save")
.is_some());
assert_eq!(
store
.get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.expect("failed to read user room receipt for 1 after save")
.len(),
1
1,
"Found a wrong number of receipts for 1 after save"
);
assert!(store
.get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.expect("failed to read user room receipt for 2 after save")
.is_empty());
let mut changes = StateChanges::default();
changes.add_receipts(room_id, second_receipt_event);
store.save_changes(&changes).await.unwrap();
store.save_changes(&changes).await.expect("Saving works");
assert!(store
.get_user_room_receipt_event(room_id, ReceiptType::Read, user_id())
.await
.unwrap()
.expect("Getting user room receipts failed")
.is_some());
assert!(store
.get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id)
.await
.unwrap()
.expect("Getting event room receipt events for first event failed")
.is_empty());
assert_eq!(
store
.get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id)
.await
.unwrap()
.expect("Getting event room receipt events for second event failed")
.len(),
1
1,
"Found a wrong number of receipts for second event after save"
);
}
@@ -467,24 +468,24 @@ macro_rules! statestore_integration_tests {
}),
};
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
assert!(store.get_media_content(&request_file).await.unwrap().is_none(), "unexpected media found");
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none(), "media not found");
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.add_media_content(&request_file, content.clone()).await.expect("adding media failed");
assert!(store.get_media_content(&request_file).await.unwrap().is_some(), "media not found though added");
store.remove_media_content(&request_file).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
store.remove_media_content(&request_file).await.expect("removing media failed");
assert!(store.get_media_content(&request_file).await.unwrap().is_none(), "media still there after removing");
store.add_media_content(&request_file, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_some());
store.add_media_content(&request_file, content.clone()).await.expect("adding media again failed");
assert!(store.get_media_content(&request_file).await.unwrap().is_some(), "media not found after adding again");
store.add_media_content(&request_thumbnail, content.clone()).await.unwrap();
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some());
store.add_media_content(&request_thumbnail, content.clone()).await.expect("adding thumbnail failed");
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_some(), "thumbnail not found");
store.remove_media_content_for_uri(uri).await.unwrap();
assert!(store.get_media_content(&request_file).await.unwrap().is_none());
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none());
store.remove_media_content_for_uri(uri).await.expect("removing all media for uri failed");
assert!(store.get_media_content(&request_file).await.unwrap().is_none(), "media wasn't removed");
assert!(store.get_media_content(&request_thumbnail).await.unwrap().is_none(), "thumbnail wasn't removed");
}
#[async_test]
@@ -526,17 +527,17 @@ macro_rules! statestore_integration_tests {
store.remove_room(room_id).await?;
assert_eq!(store.get_room_infos().await?.len(), 0);
assert!(store.get_room_infos().await?.is_empty(), "room is still there");
assert_eq!(store.get_stripped_room_infos().await?.len(), 1);
assert!(store.get_state_event(room_id, StateEventType::RoomName, "").await?.is_none());
assert_eq!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), 0);
assert!(store.get_state_events(room_id, StateEventType::RoomTopic).await?.is_empty(), "still state events found");
assert!(store.get_profile(room_id, user_id).await?.is_none());
assert!(store.get_member_event(room_id, user_id).await?.is_none());
assert_eq!(store.get_user_ids(room_id).await?.len(), 0);
assert_eq!(store.get_invited_user_ids(room_id).await?.len(), 0);
assert_eq!(store.get_joined_user_ids(room_id).await?.len(), 0);
assert_eq!(store.get_users_with_display_name(room_id, "example").await?.len(), 0);
assert!(store.get_user_ids(room_id).await?.is_empty(), "still user ids found");
assert!(store.get_invited_user_ids(room_id).await?.is_empty(), "still invited user ids found");
assert!(store.get_joined_user_ids(room_id).await?.is_empty(), "still joined users found");
assert!(store.get_users_with_display_name(room_id, "example").await?.is_empty(), "still display names found");
assert!(store
.get_room_account_data_event(room_id, RoomAccountDataEventType::Tag)
.await?
@@ -545,18 +546,18 @@ macro_rules! statestore_integration_tests {
.get_user_room_receipt_event(room_id, ReceiptType::Read, user_id)
.await?
.is_none());
assert_eq!(
assert!(
store
.get_event_room_receipt_events(room_id, ReceiptType::Read, first_receipt_event_id())
.await?
.len(),
0
.is_empty(),
"still event recepts in the store"
);
store.remove_room(stripped_room_id).await?;
assert_eq!(store.get_room_infos().await?.len(), 0);
assert_eq!(store.get_stripped_room_infos().await?.len(), 0);
assert!(store.get_room_infos().await?.is_empty(), "still room info found");
assert!(store.get_stripped_room_infos().await?.is_empty(), "still stripped room info found");
Ok(())
}
@@ -567,7 +568,7 @@ macro_rules! statestore_integration_tests {
let room_id = room_id!("!SVkFJHzfwvuaIEawgC:localhost");
// Before the first sync the timeline should be empty
assert!(store.room_timeline(room_id).await.unwrap().is_none());
assert!(store.room_timeline(room_id).await.expect("failed to read timeline").is_none(), "TL wasn't empty");
// Add sync response
let sync = SyncResponse::try_from_http_response(
@@ -700,10 +701,17 @@ macro_rules! statestore_integration_tests {
let timeline = timeline_iter.collect::<Vec<StoreResult<SyncRoomEvent>>>().await;
assert!(timeline
.into_iter()
.zip(stored_events.iter())
.all(|(a, b)| a.unwrap().event_id() == b.event_id()));
let expected: Vec<Box<EventId>> = stored_events.iter().map(|a| a.event_id().expect("event id doesn't exist")).collect();
let found: Vec<Box<EventId>> = timeline.iter().map(|a| a.as_ref().expect("object missing").event_id().clone().expect("event id missing")).collect();
for (idx, (a, b)) in timeline
.into_iter()
.zip(stored_events.iter())
.enumerate()
{
assert_eq!(a.expect("not a value").event_id(), b.event_id(), "pos {} not equal - expected: {:#?}, but found {:#?}", idx, expected, found);
}
}
}

View File

@@ -50,9 +50,6 @@ use ruma::{
EventId, MxcUri, RoomId, UserId,
};
#[cfg(feature = "store_key")]
pub mod store_key;
/// BoxStream of owned Types
pub type BoxStream<T> = Pin<Box<dyn futures_util::Stream<Item = T> + Send>>;

View File

@@ -1,299 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Facilities for StateStore implementations to reuse to manage encrypted
//! state keys
use std::convert::TryFrom;
use chacha20poly1305::{
aead::{Aead, Error as EncryptionError, NewAead},
ChaCha20Poly1305, Key, Nonce, XChaCha20Poly1305, XNonce,
};
use hmac::Hmac;
use pbkdf2::pbkdf2;
use rand::{thread_rng, Error as RngError, Fill};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use zeroize::{Zeroize, Zeroizing};
use crate::StoreError;
const VERSION: u8 = 1;
const KEY_SIZE: usize = 32;
const NONCE_SIZE: usize = 12;
const XNONCE_SIZE: usize = 24;
const KDF_SALT_SIZE: usize = 32;
#[cfg(not(test))]
const KDF_ROUNDS: u32 = 200_000;
#[cfg(test)]
const KDF_ROUNDS: u32 = 1000;
/// Local State Error for Store Keys
///
/// provides facilities to directly convert into a `StoreError` thus the most
/// common usage is to `.map_err(StoreError::into)` it.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A problem de/serializing the JSON
#[error(transparent)]
Serialization(#[from] serde_json::Error),
/// A problem with en/decrypting
#[error("Error encrypting or decrypting an event {0}")]
Encryption(String),
/// Error generating enough random data for a cryptographic operation
#[error("Error generating enough random data for a cryptographic operation")]
Random(#[from] RngError),
}
#[allow(clippy::from_over_into)]
impl Into<StoreError> for Error {
fn into(self) -> StoreError {
match self {
Error::Serialization(e) => StoreError::Json(e),
Error::Encryption(e) => StoreError::Encryption(e),
Error::Random(_) => StoreError::Encryption(self.to_string()),
}
}
}
impl From<EncryptionError> for Error {
fn from(e: EncryptionError) -> Self {
Error::Encryption(e.to_string())
}
}
/// Holding the data of the encrypted event
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct EncryptedEvent {
version: u8,
ciphertext: Vec<u8>,
nonce: Vec<u8>,
}
/// Version specific info for the key derivation method that is used.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
enum KdfInfo {
Pbkdf2ToChaCha20Poly1305 {
/// The number of PBKDF rounds that were used when deriving the store
/// key.
rounds: u32,
/// The salt that was used when the passphrase was expanded into a store
/// key.
kdf_salt: Vec<u8>,
},
}
/// Version specific info for encryption method that is used to encrypt our
/// store key.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
enum CipherTextInfo {
ChaCha20Poly1305 {
/// The nonce that was used to encrypt the ciphertext.
nonce: Vec<u8>,
/// The encrypted store key.
ciphertext: Vec<u8>,
},
}
/// An encrypted version of our store key, this can be safely stored in a
/// database.
#[derive(Debug, Serialize, Deserialize, PartialEq)]
pub struct EncryptedStoreKey {
/// Info about the key derivation method that was used to expand the
/// passphrase into an encryption key.
kdf_info: KdfInfo,
/// The ciphertext with it's accompanying additional data that is needed to
/// decrypt the store key.
ciphertext_info: CipherTextInfo,
}
/// A store key that can be used to encrypt entries in the store.
#[derive(Debug, Zeroize, PartialEq)]
pub struct StoreKey {
inner: Vec<u8>,
}
impl TryFrom<Vec<u8>> for StoreKey {
type Error = ();
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
if value.len() != KEY_SIZE {
Err(())
} else {
Ok(Self { inner: value })
}
}
}
impl StoreKey {
/// Generate a new random store key.
pub fn new() -> Result<Self, Error> {
let mut key = vec![0u8; KEY_SIZE];
let mut rng = thread_rng();
key.try_fill(&mut rng)?;
Ok(Self { inner: key })
}
/// Expand the given passphrase into a KEY_SIZE long key.
fn expand_key(passphrase: &str, salt: &[u8], rounds: u32) -> Zeroizing<Vec<u8>> {
let mut key = Zeroizing::from(vec![0u8; KEY_SIZE]);
pbkdf2::<Hmac<Sha256>>(passphrase.as_bytes(), salt, rounds, &mut *key);
key
}
/// Get the store key.
fn key(&self) -> &Key {
Key::from_slice(&self.inner)
}
/// Encrypt and export our store key using the given passphrase.
///
/// # Arguments
///
/// * `passphrase` - The passphrase that should be used to encrypt the
/// store key.
pub fn export(&self, passphrase: &str) -> Result<EncryptedStoreKey, Error> {
let mut rng = thread_rng();
let mut salt = vec![0u8; KDF_SALT_SIZE];
salt.try_fill(&mut rng)?;
let key = StoreKey::expand_key(passphrase, &salt, KDF_ROUNDS);
let key = Key::from_slice(key.as_ref());
let cipher = ChaCha20Poly1305::new(key);
let mut nonce = vec![0u8; NONCE_SIZE];
nonce.try_fill(&mut rng)?;
let ciphertext =
cipher.encrypt(Nonce::from_slice(nonce.as_ref()), self.inner.as_slice())?;
Ok(EncryptedStoreKey {
kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: KDF_ROUNDS, kdf_salt: salt },
ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext },
})
}
fn get_nonce() -> Result<Vec<u8>, RngError> {
let mut nonce = vec![0u8; XNONCE_SIZE];
let mut rng = thread_rng();
nonce.try_fill(&mut rng)?;
Ok(nonce)
}
/// Encrypt the given Event after serializing it with serde_json
pub fn encrypt(&self, event: &impl Serialize) -> Result<EncryptedEvent, Error> {
let event = serde_json::to_vec(event)?;
let nonce = StoreKey::get_nonce()?;
let cipher = XChaCha20Poly1305::new(self.key());
let xnonce = XNonce::from_slice(&nonce);
let ciphertext = cipher.encrypt(xnonce, event.as_ref())?;
Ok(EncryptedEvent { version: VERSION, ciphertext, nonce })
}
/// Decrypt the given encrypted event back into the inner
pub fn decrypt<T: for<'b> Deserialize<'b>>(&self, event: EncryptedEvent) -> Result<T, Error> {
if event.version != VERSION {
return Err(Error::Encryption(
"Error decrypting: Unknown ciphertext version".to_owned(),
));
}
let cipher = XChaCha20Poly1305::new(self.key());
let nonce = XNonce::from_slice(&event.nonce);
let plaintext = cipher.decrypt(nonce, event.ciphertext.as_ref())?;
Ok(serde_json::from_slice(&plaintext)?)
}
/// Restore a store key from an encrypted export.
///
/// # Arguments
///
/// * `passphrase` - The passphrase that should be used to encrypt the
/// store key.
///
/// * `encrypted` - The exported and encrypted version of the store key.
pub fn import(passphrase: &str, encrypted: EncryptedStoreKey) -> Result<Self, EncryptionError> {
let key = match encrypted.kdf_info {
KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds, kdf_salt } => {
Self::expand_key(passphrase, &kdf_salt, rounds)
}
};
let key = Key::from_slice(key.as_ref());
let decrypted = match encrypted.ciphertext_info {
CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext } => {
let cipher = ChaCha20Poly1305::new(key);
let nonce = Nonce::from_slice(&nonce);
cipher.decrypt(nonce, ciphertext.as_ref())?
}
};
Ok(Self { inner: decrypted })
}
}
#[cfg(test)]
mod tests {
use serde_json::{json, Value};
use super::StoreKey;
#[test]
fn generating() {
StoreKey::new().unwrap();
}
#[test]
fn encrypting() {
let passphrase = "it's a secret to everybody";
let store_key = StoreKey::new().unwrap();
let encrypted = store_key.export(passphrase).unwrap();
let decrypted = StoreKey::import(passphrase, encrypted).unwrap();
assert_eq!(store_key, decrypted);
}
#[test]
fn encrypting_events() {
let event = json!({
"content": {
"body": "Bee Gees - Stayin' Alive",
"info": {
"duration": 2140786,
"mimetype": "audio/mpeg",
"size": 1563685
},
"msgtype": "m.audio",
"url": "mxc://example.org/ffed755USFFxlgbQYZGtryd"
},
});
let store_key = StoreKey::new().unwrap();
let encrypted = store_key.encrypt(&event).unwrap();
let decrypted: Value = store_key.decrypt(encrypted).unwrap();
assert_eq!(event, decrypted);
}
}

View File

@@ -11,21 +11,22 @@ encryption = ["matrix-sdk-base/encryption", "matrix-sdk-crypto"]
default-target = "wasm32-unknown-unknown"
[dependencies]
matrix-sdk-base = { path = "../matrix-sdk-base", features = ["store_key"] }
matrix-sdk-crypto = { path = "../matrix-sdk-crypto", optional = true }
matrix-sdk-common = { path = "../matrix-sdk-common" }
matrix-sdk-store-encryption = { path = "../matrix-sdk-store-encryption" }
thiserror = "1.0.25"
anyhow = "1"
base64 = { version = "0.13.0" }
dashmap = "5.1.0"
futures-util = { version = "0.3.15", default-features = false }
indexed_db_futures = { version = "0.2.0" }
wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"] }
web-sys = { version = "0.3.35", features = ["IdbKeyRange"] }
matrix-sdk-base = { path = "../matrix-sdk-base" }
matrix-sdk-common = { path = "../matrix-sdk-common" }
matrix-sdk-store-encryption = { path = "../matrix-sdk-store-encryption" }
serde = { version = "1.0.126" }
serde_json = "1.0.64"
dashmap = "5.1.0"
thiserror = "1.0.25"
tracing = "0.1.26"
wasm-bindgen = { version = "0.2.74", features = ["serde-serialize"] }
web-sys = { version = "0.3.35", features = ["IdbKeyRange"] }
matrix-sdk-crypto = { path = "../matrix-sdk-crypto", optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
@@ -36,4 +37,5 @@ getrandom = { version = "0.2", features = ["js"] }
matrix-sdk-base = { path = "../matrix-sdk-base", features = ["testing"] }
matrix-sdk-crypto = { path = "../matrix-sdk-crypto", features = ["testing"] }
matrix-sdk-test = { path = "../matrix-sdk-test" }
uuid = "0.8"
wasm-bindgen-test = "0.3.24"

View File

@@ -1,10 +1,12 @@
#![allow(dead_code)]
use base64::{encode_config as base64_encode, STANDARD_NO_PAD};
use matrix_sdk_base::ruma::events::{
GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
};
use matrix_sdk_common::ruma::{
receipt::ReceiptType, DeviceId, EventId, MxcUri, RoomId, TransactionId, UserId,
};
use matrix_sdk_store_encryption::StoreCipher;
use wasm_bindgen::JsValue;
use web_sys::IdbKeyRange;
@@ -36,7 +38,46 @@ pub trait SafeEncode {
/// encode self into a JsValue, internally using `as_encoded_string`
/// to escape the value of self.
fn encode(&self) -> JsValue {
JsValue::from(self.as_encoded_string())
self.as_encoded_string().into()
}
/// encode self into a JsValue, internally using `as_secure_string`
/// to escape the value of self,
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> JsValue {
self.as_secure_string(table_name, store_cipher).into()
}
/// encode self securely for the given tablename with the given
/// `store_cipher` hash_key, returns the value as a base64 encoded
/// string without any padding.
fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String {
base64_encode(
store_cipher.hash_key(table_name, self.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
)
}
/// encode self into a JsValue, internally using `as_encoded_string`
/// to escape the value of self, and append the given counter
fn encode_with_counter(&self, i: usize) -> JsValue {
format!("{}{}{:0000000x}", self.as_encoded_string(), KEY_SEPARATOR, i).into()
}
/// encode self into a JsValue, internally using `as_secure_string`
/// to escape the value of self, and append the given counter
fn encode_with_counter_secure(
&self,
table_name: &str,
store_cipher: &StoreCipher,
i: usize,
) -> JsValue {
format!(
"{}{}{:0000000x}",
self.as_secure_string(table_name, store_cipher),
KEY_SEPARATOR,
i
)
.into()
}
/// Encode self into a IdbKeyRange for searching all keys that are
@@ -50,6 +91,19 @@ pub trait SafeEncode {
)
.map_err(|e| e.as_string().unwrap_or_else(|| "Creating key range failed".to_owned()))
}
fn encode_to_range_secure(
&self,
table_name: &str,
store_cipher: &StoreCipher,
) -> Result<IdbKeyRange, String> {
let key = self.as_secure_string(table_name, store_cipher);
IdbKeyRange::bound(
&JsValue::from([&key, KEY_SEPARATOR].concat()),
&JsValue::from([&key, RANGE_END].concat()),
)
.map_err(|e| e.as_string().unwrap_or_else(|| "Creating key range failed".to_owned()))
}
}
/// Implement SafeEncode for tuple of two elements, separating the escaped
@@ -62,6 +116,21 @@ where
fn as_encoded_string(&self) -> String {
[&self.0.as_encoded_string(), KEY_SEPARATOR, &self.1.as_encoded_string()].concat()
}
fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String {
[
&base64_encode(
store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
KEY_SEPARATOR,
&base64_encode(
store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
]
.concat()
}
}
/// Implement SafeEncode for tuple of three elements, separating the escaped
@@ -82,6 +151,26 @@ where
]
.concat()
}
fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String {
[
&base64_encode(
store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
KEY_SEPARATOR,
&base64_encode(
store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
KEY_SEPARATOR,
&base64_encode(
store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
]
.concat()
}
}
/// Implement SafeEncode for tuple of four elements, separating the escaped
@@ -105,6 +194,31 @@ where
]
.concat()
}
fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String {
[
&base64_encode(
store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
KEY_SEPARATOR,
&base64_encode(
store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
KEY_SEPARATOR,
&base64_encode(
store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
KEY_SEPARATOR,
&base64_encode(
store_cipher.hash_key(table_name, self.3.as_encoded_string().as_bytes()),
STANDARD_NO_PAD,
),
]
.concat()
}
}
impl SafeEncode for String {
@@ -195,4 +309,8 @@ impl SafeEncode for usize {
fn as_encoded_string(&self) -> String {
self.to_string()
}
fn as_secure_string(&self, _table_name: &str, _store_cipher: &StoreCipher) -> String {
self.to_string()
}
}

View File

@@ -20,10 +20,7 @@ use indexed_db_futures::prelude::*;
use matrix_sdk_base::{
deserialized_responses::{MemberEvent, SyncRoomEvent},
media::{MediaRequest, UniqueKey},
store::{
store_key::{self, EncryptedEvent, StoreKey},
BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError,
},
store::{BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError},
RoomInfo,
};
use matrix_sdk_common::{
@@ -46,24 +43,23 @@ use matrix_sdk_common::{
EventId, MxcUri, RoomId, RoomVersionId, UserId,
},
};
use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
use wasm_bindgen::JsValue;
use web_sys::IdbKeyRange;
use crate::safe_encode::SafeEncode;
#[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType {
Unencrypted,
Encrypted(store_key::EncryptedStoreKey),
}
#[derive(Clone, Serialize, Deserialize)]
struct StoreKeyWrapper(Vec<u8>);
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Encryption(#[from] store_key::Error),
Encryption(#[from] EncryptionError),
#[error("DomException {name} ({code}): {message}")]
DomException { name: String, message: String, code: u16 },
#[error(transparent)]
@@ -86,9 +82,17 @@ impl From<SerializationError> for StoreError {
SerializationError::Json(e) => StoreError::Json(e),
SerializationError::StoreError(e) => e,
SerializationError::Encryption(e) => match e {
store_key::Error::Random(e) => StoreError::Encryption(e.to_string()),
store_key::Error::Serialization(e) => StoreError::Json(e),
store_key::Error::Encryption(e) => StoreError::Encryption(e),
EncryptionError::Random(e) => StoreError::Encryption(e.to_string()),
EncryptionError::Serialization(e) => StoreError::Json(e),
EncryptionError::Encryption(e) => StoreError::Encryption(e.to_string()),
EncryptionError::Version(found, expected) => StoreError::Encryption(format!(
"Bad Database Encryption Version: expected {} found {}",
expected, found
)),
EncryptionError::Length(found, expected) => StoreError::Encryption(format!(
"The database key an invalid length: expected {} found {}",
expected, found
)),
},
_ => StoreError::Backend(anyhow!(e)),
}
@@ -138,7 +142,7 @@ mod KEYS {
pub struct IndexeddbStore {
name: String,
pub(crate) inner: IdbDatabase,
store_key: Option<StoreKey>,
store_cipher: Option<StoreCipher>,
}
impl std::fmt::Debug for IndexeddbStore {
@@ -150,7 +154,7 @@ impl std::fmt::Debug for IndexeddbStore {
type Result<A, E = SerializationError> = std::result::Result<A, E>;
impl IndexeddbStore {
async fn open_helper(name: String, store_key: Option<StoreKey>) -> Result<Self> {
async fn open_helper(name: String, store_cipher: Option<StoreCipher>) -> Result<Self> {
// Open my_db v1
let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?;
db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> {
@@ -193,7 +197,7 @@ impl IndexeddbStore {
let db: IdbDatabase = db_req.into_future().await?;
Ok(Self { name, inner: db, store_key })
Ok(Self { name, inner: db, store_cipher })
}
#[allow(dead_code)]
@@ -205,7 +209,7 @@ impl IndexeddbStore {
Ok(Self::inner_open_with_passphrase(name, passphrase).await?)
}
async fn inner_open_with_passphrase(name: String, passphrase: &str) -> Result<Self> {
pub(crate) async fn inner_open_with_passphrase(name: String, passphrase: &str) -> Result<Self> {
let name = format!("{:0}::matrix-sdk-state", name);
let mut db_req: OpenDbRequest = IdbDatabase::open_u32(&name, 1)?;
@@ -225,33 +229,25 @@ impl IndexeddbStore {
db.transaction_on_one_with_mode("matrix-sdk-state", IdbTransactionMode::Readwrite)?;
let ob = tx.object_store("matrix-sdk-state")?;
let store_key: Option<DatabaseType> = ob
let cipher = if let Some(StoreKeyWrapper(inner)) = ob
.get(&JsValue::from_str(KEYS::STORE_KEY))?
.await?
.map(|k| k.into_serde())
.transpose()?;
let store_key = if let Some(key) = store_key {
if let DatabaseType::Encrypted(k) = key {
StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)?
} else {
return Err(StoreError::UnencryptedStore.into());
}
.map(|v| v.into_serde())
.transpose()?
{
StoreCipher::import(passphrase, &inner)?
} else {
let key = StoreKey::new().map_err::<SerializationError, _>(|e| e.into())?;
let encrypted_key = DatabaseType::Encrypted(
key.export(passphrase).map_err::<SerializationError, _>(|e| e.into())?,
);
let cipher = StoreCipher::new()?;
ob.put_key_val(
&JsValue::from_str(KEYS::STORE_KEY),
&JsValue::from_serde(&encrypted_key)?,
&JsValue::from_serde(&StoreKeyWrapper(cipher.export(passphrase)?))?,
)?;
key
cipher
};
tx.await.into_result()?;
IndexeddbStore::open_helper(name, Some(store_key)).await
IndexeddbStore::open_helper(name, Some(cipher)).await
}
pub async fn open_with_name(name: String) -> StoreResult<Self> {
@@ -262,8 +258,8 @@ impl IndexeddbStore {
&self,
event: &impl Serialize,
) -> std::result::Result<JsValue, SerializationError> {
Ok(match self.store_key {
Some(ref key) => JsValue::from_serde(&key.encrypt(event)?)?,
Ok(match self.store_cipher {
Some(ref cipher) => JsValue::from_serde(&cipher.encrypt_value_typed(event)?)?,
None => JsValue::from_serde(event)?,
})
}
@@ -272,15 +268,47 @@ impl IndexeddbStore {
&self,
event: JsValue,
) -> std::result::Result<T, SerializationError> {
match self.store_key {
Some(ref key) => {
let encrypted: EncryptedEvent = event.into_serde()?;
Ok(key.decrypt(encrypted)?)
}
match self.store_cipher {
Some(ref cipher) => Ok(cipher.decrypt_value_typed(event.into_serde()?)?),
None => Ok(event.into_serde()?),
}
}
fn encode_key<T>(&self, table_name: &str, key: T) -> JsValue
where
T: SafeEncode,
{
match self.store_cipher {
Some(ref cipher) => key.encode_secure(table_name, cipher),
None => key.encode(),
}
}
fn encode_to_range<T>(
&self,
table_name: &str,
key: T,
) -> Result<IdbKeyRange, SerializationError>
where
T: SafeEncode,
{
match self.store_cipher {
Some(ref cipher) => key.encode_to_range_secure(table_name, cipher),
None => key.encode_to_range(),
}
.map_err(|e| SerializationError::StoreError(StoreError::Backend(anyhow!(e))))
}
fn encode_key_with_counter<T>(&self, table_name: &str, key: &T, i: usize) -> JsValue
where
T: SafeEncode,
{
match self.store_cipher {
Some(ref cipher) => key.encode_with_counter_secure(table_name, cipher, i),
None => key.encode_with_counter(i),
}
}
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
let tx = self
.inner
@@ -288,7 +316,10 @@ impl IndexeddbStore {
let obj = tx.object_store(KEYS::SESSION)?;
obj.put_key_val(&(KEYS::FILTER, filter_name).encode(), &JsValue::from_str(filter_id))?;
obj.put_key_val(
&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)),
&JsValue::from_str(filter_id),
)?;
tx.await.into_result()?;
@@ -300,7 +331,7 @@ impl IndexeddbStore {
.inner
.transaction_on_one_with_mode(KEYS::SESSION, IdbTransactionMode::Readonly)?
.object_store(KEYS::SESSION)?
.get(&(KEYS::FILTER, filter_name).encode())?
.get(&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)))?
.await?
.and_then(|f| f.as_string()))
}
@@ -373,7 +404,7 @@ impl IndexeddbStore {
let store = tx.object_store(KEYS::DISPLAY_NAMES)?;
for (room_id, ambiguity_maps) in &changes.ambiguity_maps {
for (display_name, map) in ambiguity_maps {
let key = (room_id, display_name).encode();
let key = self.encode_key(KEYS::DISPLAY_NAMES, (room_id, display_name));
store.put_key_val(&key, &self.serialize_event(&map)?)?;
}
@@ -383,7 +414,10 @@ impl IndexeddbStore {
if !changes.account_data.is_empty() {
let store = tx.object_store(KEYS::ACCOUNT_DATA)?;
for (event_type, event) in &changes.account_data {
store.put_key_val(&event_type.encode(), &self.serialize_event(&event)?)?;
store.put_key_val(
&self.encode_key(KEYS::ACCOUNT_DATA, event_type),
&self.serialize_event(&event)?,
)?;
}
}
@@ -391,7 +425,7 @@ impl IndexeddbStore {
let store = tx.object_store(KEYS::ROOM_ACCOUNT_DATA)?;
for (room, events) in &changes.room_account_data {
for (event_type, event) in events {
let key = (room, event_type).encode();
let key = self.encode_key(KEYS::ROOM_ACCOUNT_DATA, (room, event_type));
store.put_key_val(&key, &self.serialize_event(&event)?)?;
}
}
@@ -402,7 +436,7 @@ impl IndexeddbStore {
for (room, event_types) in &changes.state {
for (event_type, events) in event_types {
for (state_key, event) in events {
let key = (room, event_type, state_key).encode();
let key = self.encode_key(KEYS::ROOM_STATE, (room, event_type, state_key));
store.put_key_val(&key, &self.serialize_event(&event)?)?;
}
}
@@ -412,21 +446,30 @@ impl IndexeddbStore {
if !changes.room_infos.is_empty() {
let store = tx.object_store(KEYS::ROOM_INFOS)?;
for (room_id, room_info) in &changes.room_infos {
store.put_key_val(&room_id.encode(), &self.serialize_event(&room_info)?)?;
store.put_key_val(
&self.encode_key(KEYS::ROOM_INFOS, room_id),
&self.serialize_event(&room_info)?,
)?;
}
}
if !changes.presence.is_empty() {
let store = tx.object_store(KEYS::PRESENCE)?;
for (sender, event) in &changes.presence {
store.put_key_val(&sender.encode(), &self.serialize_event(&event)?)?;
store.put_key_val(
&self.encode_key(KEYS::PRESENCE, sender),
&self.serialize_event(&event)?,
)?;
}
}
if !changes.stripped_room_infos.is_empty() {
let store = tx.object_store(KEYS::STRIPPED_ROOM_INFOS)?;
for (room_id, info) in &changes.stripped_room_infos {
store.put_key_val(&room_id.encode(), &self.serialize_event(&info)?)?;
store.put_key_val(
&self.encode_key(KEYS::STRIPPED_ROOM_INFOS, room_id),
&self.serialize_event(&info)?,
)?;
}
}
@@ -434,7 +477,7 @@ impl IndexeddbStore {
let store = tx.object_store(KEYS::STRIPPED_MEMBERS)?;
for (room, events) in &changes.stripped_members {
for event in events.values() {
let key = (room, &event.state_key).encode();
let key = self.encode_key(KEYS::STRIPPED_MEMBERS, (room, &event.state_key));
store.put_key_val(&key, &self.serialize_event(&event)?)?;
}
}
@@ -445,7 +488,8 @@ impl IndexeddbStore {
for (room, event_types) in &changes.stripped_state {
for (event_type, events) in event_types {
for (state_key, event) in events {
let key = (room, event_type, state_key).encode();
let key = self
.encode_key(KEYS::STRIPPED_ROOM_STATE, (room, event_type, state_key));
store.put_key_val(&key, &self.serialize_event(&event)?)?;
}
}
@@ -462,27 +506,39 @@ impl IndexeddbStore {
let members = tx.object_store(KEYS::MEMBERS)?;
for event in events.values() {
let key = (room, &event.state_key).encode();
let key = (room, &event.state_key);
match event.content.membership {
MembershipState::Join => {
joined.put_key_val_owned(&key, &event.state_key.encode())?;
invited.delete(&key)?;
joined.put_key_val_owned(
&self.encode_key(KEYS::JOINED_USER_IDS, key),
&self.serialize_event(&event.state_key)?,
)?;
invited.delete(&self.encode_key(KEYS::INVITED_USER_IDS, key))?;
}
MembershipState::Invite => {
invited.put_key_val_owned(&key, &event.state_key.encode())?;
joined.delete(&key)?;
invited.put_key_val_owned(
&self.encode_key(KEYS::INVITED_USER_IDS, key),
&self.serialize_event(&event.state_key)?,
)?;
joined.delete(&self.encode_key(KEYS::JOINED_USER_IDS, key))?;
}
_ => {
joined.delete(&key)?;
invited.delete(&key)?;
joined.delete(&self.encode_key(KEYS::JOINED_USER_IDS, key))?;
invited.delete(&self.encode_key(KEYS::INVITED_USER_IDS, key))?;
}
}
members.put_key_val_owned(&key, &self.serialize_event(&event)?)?;
members.put_key_val_owned(
&self.encode_key(KEYS::MEMBERS, key),
&self.serialize_event(&event)?,
)?;
if let Some(profile) = profile_changes.and_then(|p| p.get(&event.state_key)) {
profiles_store.put_key_val_owned(&key, &self.serialize_event(&profile)?)?;
profiles_store.put_key_val_owned(
&self.encode_key(KEYS::PROFILES, key),
&self.serialize_event(&profile)?,
)?;
}
}
}
@@ -496,15 +552,20 @@ impl IndexeddbStore {
for (event_id, receipts) in &content.0 {
for (receipt_type, receipts) in receipts {
for (user_id, receipt) in receipts {
let key = (room, receipt_type, user_id).encode();
let key = self.encode_key(
KEYS::ROOM_USER_RECEIPTS,
(room, receipt_type, user_id),
);
if let Some((old_event, _)) =
room_user_receipts.get(&key)?.await?.and_then(|f| {
self.deserialize_event::<(Box<EventId>, Receipt)>(f).ok()
})
{
room_event_receipts
.delete(&(room, receipt_type, &old_event, user_id).encode())?;
room_event_receipts.delete(&self.encode_key(
KEYS::ROOM_EVENT_RECEIPTS,
(room, receipt_type, &old_event, user_id),
))?;
}
room_user_receipts
@@ -512,8 +573,11 @@ impl IndexeddbStore {
// Add the receipt to the room event receipts
room_event_receipts.put_key_val(
&(room, receipt_type, event_id, user_id).encode(),
&self.serialize_event(&receipt)?,
&self.encode_key(
KEYS::ROOM_EVENT_RECEIPTS,
(room, receipt_type, event_id, user_id),
),
&self.serialize_event(&(user_id, receipt))?,
)?;
}
}
@@ -533,18 +597,19 @@ impl IndexeddbStore {
} else {
info!("Save new timeline batch from messages response for {}", room_id);
}
let room_key = room_id.encode();
let metadata: Option<TimelineMetadata> = if timeline.limited {
info!(
"Delete stored timeline for {} because the sync response was limited",
room_id
);
let range = room_id.encode_to_range().map_err(StoreError::Codec)?;
let stores =
&[&timeline_store, &timeline_metadata_store, &event_id_to_position_store];
for store in stores {
let stores = &[
(KEYS::ROOM_TIMELINE, &timeline_store),
(KEYS::ROOM_TIMELINE_METADATA, &timeline_metadata_store),
(KEYS::ROOM_EVENT_ID_TO_POSITION, &event_id_to_position_store),
];
for (table_name, store) in stores {
let range = self.encode_to_range(table_name, room_id)?;
for key in store.get_all_keys_with_key(&range)?.await?.iter() {
store.delete(&key)?;
}
@@ -553,7 +618,7 @@ impl IndexeddbStore {
None
} else {
let metadata: Option<TimelineMetadata> = timeline_metadata_store
.get(&room_key)?
.get(&self.encode_key(KEYS::ROOM_TIMELINE_METADATA, room_id))?
.await?
.map(|v| v.into_serde())
.transpose()?;
@@ -570,7 +635,10 @@ impl IndexeddbStore {
let mut delete_timeline = false;
for event in &timeline.events {
if let Some(event_id) = event.event_id() {
let event_key = (room_id, &event_id).encode();
let event_key = self.encode_key(
KEYS::ROOM_EVENT_ID_TO_POSITION,
(room_id, event_id),
);
if event_id_to_position_store
.count_with_key_owned(event_key)?
.await?
@@ -588,13 +656,13 @@ impl IndexeddbStore {
room_id
);
let range = room_id.encode_to_range().map_err(StoreError::Codec)?;
let stores = &[
&timeline_store,
&timeline_metadata_store,
&event_id_to_position_store,
(KEYS::ROOM_TIMELINE, &timeline_store),
(KEYS::ROOM_TIMELINE_METADATA, &timeline_metadata_store),
(KEYS::ROOM_EVENT_ID_TO_POSITION, &event_id_to_position_store),
];
for store in stores {
for (table_name, store) in stores {
let range = self.encode_to_range(table_name, room_id)?;
for key in store.get_all_keys_with_key(&range)?.await?.iter() {
store.delete(&key)?;
}
@@ -626,7 +694,7 @@ impl IndexeddbStore {
if timeline.sync {
let room_version = room_infos
.get(&room_key)?
.get(&self.encode_key(KEYS::ROOM_INFOS, room_id))?
.await?
.map(|r| self.deserialize_event::<RoomInfo>(r))
.transpose()?
@@ -646,7 +714,10 @@ impl IndexeddbStore {
),
)) = event.event.deserialize()
{
let redacts_key = (room_id, &redaction.redacts).encode();
let redacts_key = self.encode_key(
KEYS::ROOM_EVENT_ID_TO_POSITION,
(room_id, &redaction.redacts),
);
if let Some(position_key) =
event_id_to_position_store.get_owned(redacts_key)?.await?
{
@@ -673,22 +744,32 @@ impl IndexeddbStore {
}
metadata.start_position -= 1;
let key = (room_id, &metadata.start_position).encode();
let key = self.encode_key_with_counter(
KEYS::ROOM_TIMELINE,
room_id,
metadata.start_position,
);
// Only add event with id to the position map
if let Some(event_id) = event.event_id() {
let event_key = (room_id, &event_id).encode();
let event_key = self
.encode_key(KEYS::ROOM_EVENT_ID_TO_POSITION, (room_id, event_id));
event_id_to_position_store.put_key_val(&event_key, &key)?;
}
timeline_store.put_key_val_owned(key, &self.serialize_event(&event)?)?;
timeline_store.put_key_val_owned(&key, &self.serialize_event(&event)?)?;
}
} else {
for event in &timeline.events {
metadata.end_position += 1;
let key = (room_id, &metadata.end_position).encode();
let key = self.encode_key_with_counter(
KEYS::ROOM_TIMELINE,
room_id,
metadata.end_position,
);
// Only add event with id to the position map
if let Some(event_id) = event.event_id() {
let event_key = (room_id, &event_id).encode();
let event_key = self
.encode_key(KEYS::ROOM_EVENT_ID_TO_POSITION, (room_id, event_id));
event_id_to_position_store.put_key_val(&event_key, &key)?;
}
@@ -696,8 +777,10 @@ impl IndexeddbStore {
}
}
timeline_metadata_store
.put_key_val_owned(room_key, &JsValue::from_serde(&metadata)?)?;
timeline_metadata_store.put_key_val_owned(
&self.encode_key(KEYS::ROOM_TIMELINE_METADATA, room_id),
&JsValue::from_serde(&metadata)?,
)?;
}
}
@@ -708,7 +791,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::PRESENCE, IdbTransactionMode::Readonly)?
.object_store(KEYS::PRESENCE)?
.get(&user_id.encode())?
.get(&self.encode_key(KEYS::PRESENCE, user_id))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
@@ -723,7 +806,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::ROOM_STATE, IdbTransactionMode::Readonly)?
.object_store(KEYS::ROOM_STATE)?
.get(&(room_id, &event_type, state_key).encode())?
.get(&self.encode_key(KEYS::ROOM_STATE, (room_id, event_type, state_key)))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
@@ -734,7 +817,7 @@ impl IndexeddbStore {
room_id: &RoomId,
event_type: StateEventType,
) -> Result<Vec<Raw<AnySyncStateEvent>>> {
let range = (room_id, &event_type).encode_to_range().map_err(StoreError::Codec)?;
let range = self.encode_to_range(KEYS::ROOM_STATE, (room_id, event_type))?;
Ok(self
.inner
.transaction_on_one_with_mode(KEYS::ROOM_STATE, IdbTransactionMode::Readonly)?
@@ -754,7 +837,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::PROFILES, IdbTransactionMode::Readonly)?
.object_store(KEYS::PROFILES)?
.get(&(room_id, user_id).encode())?
.get(&self.encode_key(KEYS::PROFILES, (room_id, user_id)))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
@@ -768,31 +851,19 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::MEMBERS, IdbTransactionMode::Readonly)?
.object_store(KEYS::MEMBERS)?
.get(&(room_id, state_key).encode())?
.get(&self.encode_key(KEYS::MEMBERS, (room_id, state_key)))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
}
pub async fn get_user_ids_stream(&self, room_id: &RoomId) -> Result<Vec<Box<UserId>>> {
let range = room_id.encode_to_range().map_err(StoreError::Codec)?;
let skip = room_id.as_encoded_string().len() + 1;
Ok(self
.inner
.transaction_on_one_with_mode(KEYS::MEMBERS, IdbTransactionMode::Readonly)?
.object_store(KEYS::MEMBERS)?
.get_all_keys_with_key(&range)?
.await?
.iter()
.filter_map(|key| match key.as_string() {
Some(k) => UserId::parse(&k[skip..]).ok(),
_ => None,
})
.collect::<Vec<_>>())
Ok([self.get_invited_user_ids(room_id).await?, self.get_joined_user_ids(room_id).await?]
.concat())
}
pub async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<Box<UserId>>> {
let range = room_id.encode_to_range().map_err(StoreError::Codec)?;
let range = self.encode_to_range(KEYS::INVITED_USER_IDS, room_id)?;
let entries = self
.inner
.transaction_on_one_with_mode(KEYS::INVITED_USER_IDS, IdbTransactionMode::Readonly)?
@@ -807,7 +878,7 @@ impl IndexeddbStore {
}
pub async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<Box<UserId>>> {
let range = room_id.encode_to_range().map_err(StoreError::Codec)?;
let range = self.encode_to_range(KEYS::JOINED_USER_IDS, room_id)?;
Ok(self
.inner
.transaction_on_one_with_mode(KEYS::JOINED_USER_IDS, IdbTransactionMode::Readonly)?
@@ -855,7 +926,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::DISPLAY_NAMES, IdbTransactionMode::Readonly)?
.object_store(KEYS::DISPLAY_NAMES)?
.get(&(room_id, display_name).encode())?
.get(&self.encode_key(KEYS::DISPLAY_NAMES, (room_id, display_name)))?
.await?
.map(|f| {
self.deserialize_event::<BTreeSet<Box<UserId>>>(f)
@@ -871,7 +942,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::ACCOUNT_DATA, IdbTransactionMode::Readonly)?
.object_store(KEYS::ACCOUNT_DATA)?
.get(&JsValue::from_str(&event_type.to_string()))?
.get(&self.encode_key(KEYS::ACCOUNT_DATA, event_type))?
.await?
.map(|f| self.deserialize_event(f).map_err::<SerializationError, _>(|e| e))
.transpose()
@@ -885,7 +956,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::ROOM_ACCOUNT_DATA, IdbTransactionMode::Readonly)?
.object_store(KEYS::ROOM_ACCOUNT_DATA)?
.get(&(room_id.as_str(), &event_type.to_string()).encode())?
.get(&self.encode_key(KEYS::ROOM_ACCOUNT_DATA, (room_id, event_type)))?
.await?
.map(|f| self.deserialize_event(f).map_err::<SerializationError, _>(|e| e))
.transpose()
@@ -900,7 +971,7 @@ impl IndexeddbStore {
self.inner
.transaction_on_one_with_mode(KEYS::ROOM_USER_RECEIPTS, IdbTransactionMode::Readonly)?
.object_store(KEYS::ROOM_USER_RECEIPTS)?
.get(&(room_id.as_str(), receipt_type.as_ref(), user_id.as_str()).encode())?
.get(&self.encode_key(KEYS::ROOM_USER_RECEIPTS, (room_id, receipt_type, user_id)))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
@@ -912,38 +983,25 @@ impl IndexeddbStore {
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(Box<UserId>, Receipt)>> {
let key = (room_id, &receipt_type, event_id);
let prefix_len = key.as_encoded_string().len() + 1;
let range = key.encode_to_range().map_err(StoreError::Codec)?;
let range =
self.encode_to_range(KEYS::ROOM_EVENT_RECEIPTS, (room_id, &receipt_type, event_id))?;
let tx = self.inner.transaction_on_one_with_mode(
KEYS::ROOM_EVENT_RECEIPTS,
IdbTransactionMode::Readonly,
)?;
let store = tx.object_store(KEYS::ROOM_EVENT_RECEIPTS)?;
let mut all = Vec::new();
for k in store.get_all_keys_with_key(&range)?.await?.iter() {
// FIXME: we should probably parallelize this...
let res = store
.get(&k)?
.await?
.ok_or_else(|| StoreError::Codec(format!("no data at {:?}", k)))?;
let u = if let Some(k_str) = k.as_string() {
UserId::parse(&k_str[prefix_len..])
.map_err(|e| StoreError::Codec(format!("{:?}", e)))?
} else {
return Err(StoreError::Codec(format!("{:?}", k)).into());
};
let r = self
.deserialize_event::<Receipt>(res)
.map_err(|e| StoreError::Codec(e.to_string()))?;
all.push((u, r));
}
Ok(all)
Ok(store
.get_all_with_key(&range)?
.await?
.iter()
.filter_map(|f| self.deserialize_event(f).ok())
.collect::<Vec<_>>())
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
let key = (&request.source.unique_key(), &request.format.unique_key()).encode();
let key = self
.encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key()));
let tx =
self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?;
@@ -953,7 +1011,8 @@ impl IndexeddbStore {
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
let key = (&request.source.unique_key(), &request.format.unique_key()).encode();
let key = self
.encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key()));
self.inner
.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readonly)?
.object_store(KEYS::MEDIA)?
@@ -997,7 +1056,8 @@ impl IndexeddbStore {
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
let key = (&request.source.unique_key(), &request.format.unique_key()).encode();
let key = self
.encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key()));
let tx =
self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?;
@@ -1007,7 +1067,7 @@ impl IndexeddbStore {
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let range = uri.encode_to_range().map_err(StoreError::Codec)?;
let range = self.encode_to_range(KEYS::MEDIA, uri)?;
let tx =
self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?;
let store = tx.object_store(KEYS::MEDIA)?;
@@ -1050,14 +1110,13 @@ impl IndexeddbStore {
.inner
.transaction_on_multi_with_mode(&all_stores, IdbTransactionMode::Readwrite)?;
let room_key = room_id.encode();
for store_name in direct_stores {
tx.object_store(store_name)?.delete(&room_key)?;
tx.object_store(store_name)?.delete(&self.encode_key(store_name, room_id))?;
}
let range = room_id.encode_to_range().map_err(StoreError::Codec)?;
for store_name in prefixed_stores {
let store = tx.object_store(store_name)?;
let range = self.encode_to_range(store_name, room_id)?;
for key in store.get_all_keys_with_key(&range)?.await?.iter() {
store.delete(&key)?;
}
@@ -1069,7 +1128,6 @@ impl IndexeddbStore {
&self,
room_id: &RoomId,
) -> Result<Option<(BoxStream<StoreResult<SyncRoomEvent>>, Option<String>)>> {
let key = room_id.encode();
let tx = self.inner.transaction_on_multi_with_mode(
&[KEYS::ROOM_TIMELINE, KEYS::ROOM_TIMELINE_METADATA],
IdbTransactionMode::Readonly,
@@ -1077,16 +1135,23 @@ impl IndexeddbStore {
let timeline = tx.object_store(KEYS::ROOM_TIMELINE)?;
let metadata = tx.object_store(KEYS::ROOM_TIMELINE_METADATA)?;
let metadata: Option<TimelineMetadata> =
metadata.get(&key)?.await?.map(|v| v.into_serde()).transpose()?;
if metadata.is_none() {
info!("No timeline for {} was previously stored", room_id);
return Ok(None);
}
let end_token = metadata.and_then(|m| m.end);
let tlm: TimelineMetadata = match metadata
.get(&self.encode_key(KEYS::ROOM_TIMELINE_METADATA, room_id))?
.await?
.map(|v| v.into_serde())
.transpose()?
{
Some(tl) => tl,
_ => {
info!("No timeline for {} was previously stored", room_id);
return Ok(None);
}
};
let end_token = tlm.end;
#[allow(clippy::needless_collect)]
let timeline: Vec<StoreResult<SyncRoomEvent>> = timeline
.get_all_with_key(&key)?
.get_all_with_key(&self.encode_to_range(KEYS::ROOM_TIMELINE, room_id)?)?
.await?
.iter()
.map(|v| self.deserialize_event(v).map_err(|e| e.into()))
@@ -1280,3 +1345,23 @@ mod tests {
statestore_integration_tests! { integration }
}
#[cfg(test)]
mod encrypted_tests {
#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
use matrix_sdk_base::statestore_integration_tests;
use uuid::Uuid;
use super::{IndexeddbStore, Result, StoreCipher};
async fn get_store() -> Result<IndexeddbStore> {
let db_name =
format!("test-state-encrypted-{}", Uuid::new_v4().to_hyphenated().to_string());
let key = StoreCipher::new()?;
Ok(IndexeddbStore::open_helper(db_name, Some(key)).await?)
}
statestore_integration_tests! { integration }
}

View File

@@ -12,7 +12,7 @@ crypto-store = ["matrix-sdk-crypto"]
[dependencies]
futures-core = "0.3.15"
futures-util = { version = "0.3.15", default-features = false }
matrix-sdk-base = { path = "../matrix-sdk-base", features = ["store_key"], optional = true }
matrix-sdk-base = { path = "../matrix-sdk-base", optional = true }
matrix-sdk-common = { path = "../matrix-sdk-common" }
matrix-sdk-crypto = { path = "../matrix-sdk-crypto", optional = true }
matrix-sdk-store-encryption = { path = "../matrix-sdk-store-encryption" }

View File

@@ -24,10 +24,7 @@ use dashmap::DashSet;
use matrix_sdk_common::{
async_trait,
locks::Mutex,
ruma::{
events::{room_key_request::RequestedKeyInfo, secret::request::SecretName},
DeviceId, RoomId, TransactionId, UserId,
},
ruma::{events::room_key_request::RequestedKeyInfo, DeviceId, RoomId, TransactionId, UserId},
};
use matrix_sdk_crypto::{
olm::{
@@ -69,12 +66,20 @@ impl EncodeKey for InboundGroupSession {
fn encode(&self) -> Vec<u8> {
(self.room_id(), self.sender_key(), self.session_id()).encode()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
(self.room_id(), self.sender_key(), self.session_id())
.encode_secure(table_name, store_cipher)
}
}
impl EncodeKey for OutboundGroupSession {
fn encode(&self) -> Vec<u8> {
self.room_id().encode()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
self.room_id().encode_secure(table_name, store_cipher)
}
}
impl EncodeKey for Session {
@@ -85,6 +90,15 @@ impl EncodeKey for Session {
[sender_key.as_bytes(), &[ENCODE_SEPARATOR], session_id.as_bytes(), &[ENCODE_SEPARATOR]]
.concat()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let sender_key =
store_cipher.hash_key(table_name, self.sender_key().to_base64().as_bytes());
let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes());
[sender_key.as_slice(), &[ENCODE_SEPARATOR], session_id.as_slice(), &[ENCODE_SEPARATOR]]
.concat()
}
}
impl EncodeKey for SecretInfo {
@@ -94,31 +108,6 @@ impl EncodeKey for SecretInfo {
SecretInfo::SecretRequest(s) => s.encode(),
}
}
}
impl EncodeKey for RequestedKeyInfo {
fn encode(&self) -> Vec<u8> {
(&self.room_id, &self.sender_key, &self.algorithm, &self.session_id).encode()
}
}
impl EncodeKey for ReadOnlyDevice {
fn encode(&self) -> Vec<u8> {
(self.user_id(), self.device_id()).encode()
}
}
impl EncodeSecureKey for (&UserId, &DeviceId) {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let user_id = store_cipher.hash_key(table_name, self.0.as_bytes());
let device_id = store_cipher.hash_key(table_name, self.1.as_bytes());
[user_id.as_slice(), &[ENCODE_SEPARATOR], device_id.as_slice(), &[ENCODE_SEPARATOR]]
.concat()
}
}
impl EncodeSecureKey for SecretInfo {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
match self {
SecretInfo::KeyRequest(k) => k.encode_secure(table_name, store_cipher),
@@ -127,15 +116,10 @@ impl EncodeSecureKey for SecretInfo {
}
}
impl EncodeSecureKey for SecretName {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let name = store_cipher.hash_key(table_name, self.as_ref().as_bytes());
[name.as_slice(), &[ENCODE_SEPARATOR]].concat()
impl EncodeKey for RequestedKeyInfo {
fn encode(&self) -> Vec<u8> {
(&self.room_id, &self.sender_key, &self.algorithm, &self.session_id).encode()
}
}
impl EncodeSecureKey for RequestedKeyInfo {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let room_id = store_cipher.hash_key(table_name, self.room_id.as_bytes());
let sender_key = store_cipher.hash_key(table_name, self.sender_key.as_bytes());
@@ -156,81 +140,15 @@ impl EncodeSecureKey for RequestedKeyInfo {
}
}
impl EncodeSecureKey for UserId {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let user_id = store_cipher.hash_key(table_name, self.as_bytes());
[user_id.as_slice(), &[0xff]].concat()
impl EncodeKey for ReadOnlyDevice {
fn encode(&self) -> Vec<u8> {
(self.user_id(), self.device_id()).encode()
}
}
impl EncodeSecureKey for ReadOnlyDevice {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
(self.user_id(), self.device_id()).encode_secure(table_name, store_cipher)
}
}
impl EncodeSecureKey for str {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let key = store_cipher.hash_key(table_name, self.as_bytes());
[key.as_slice(), &[ENCODE_SEPARATOR]].concat()
}
}
impl EncodeSecureKey for OutboundGroupSession {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
self.room_id().encode_secure(table_name, store_cipher)
}
}
impl EncodeSecureKey for RoomId {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let room_id = store_cipher.hash_key(table_name, self.as_bytes());
[room_id.as_slice(), &[ENCODE_SEPARATOR]].concat()
}
}
impl EncodeSecureKey for InboundGroupSession {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
(self.room_id(), self.sender_key(), self.session_id())
.encode_secure(table_name, store_cipher)
}
}
impl EncodeSecureKey for (&RoomId, &str, &str) {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let first = store_cipher.hash_key(table_name, self.0.as_bytes());
let second = store_cipher.hash_key(table_name, self.1.as_bytes());
let third = store_cipher.hash_key(table_name, self.2.as_bytes());
[
first.as_slice(),
&[ENCODE_SEPARATOR],
second.as_slice(),
&[ENCODE_SEPARATOR],
third.as_slice(),
&[ENCODE_SEPARATOR],
]
.concat()
}
}
impl EncodeSecureKey for Session {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let sender_key =
store_cipher.hash_key(table_name, self.sender_key().to_base64().as_bytes());
let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes());
[sender_key.as_slice(), &[ENCODE_SEPARATOR], session_id.as_slice(), &[ENCODE_SEPARATOR]]
.concat()
}
}
trait EncodeSecureKey {
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8>;
}
#[derive(Clone, Debug)]
pub struct AccountInfo {
user_id: Arc<UserId>,
@@ -332,11 +250,7 @@ impl SledStore {
}
}
fn encode_key<T: EncodeSecureKey + EncodeKey + ?Sized>(
&self,
table_name: &str,
key: &T,
) -> Vec<u8> {
fn encode_key<T: EncodeKey>(&self, table_name: &str, key: T) -> Vec<u8> {
if let Some(store_cipher) = &*self.store_cipher {
key.encode_secure(table_name, store_cipher).to_vec()
} else {

View File

@@ -1,91 +1,163 @@
use std::{borrow::Cow, ops::Deref};
use matrix_sdk_common::ruma::{
events::{secret::request::SecretName, GlobalAccountDataEventType, StateEventType},
DeviceId, EventId, MxcUri, RoomId, TransactionId, UserId,
events::{
secret::request::SecretName, GlobalAccountDataEventType, RoomAccountDataEventType,
StateEventType,
},
receipt::ReceiptType,
DeviceId, EventEncryptionAlgorithm, EventId, MxcUri, RoomId, TransactionId, UserId,
};
use matrix_sdk_store_encryption::StoreCipher;
pub const ENCODE_SEPARATOR: u8 = 0xff;
pub trait EncodeKey {
fn encode(&self) -> Vec<u8>;
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
unimplemented!()
}
fn encode(&self) -> Vec<u8> {
[self.encode_as_bytes().deref(), &[ENCODE_SEPARATOR]].concat()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
let key = store_cipher.hash_key(table_name, &self.encode_as_bytes());
[key.as_slice(), &[ENCODE_SEPARATOR]].concat()
}
}
impl<T: EncodeKey + ?Sized> EncodeKey for &T {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
T::encode_as_bytes(self)
}
fn encode(&self) -> Vec<u8> {
T::encode(self)
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
T::encode_secure(self, table_name, store_cipher)
}
}
impl<T: EncodeKey + ?Sized> EncodeKey for Box<T> {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
T::encode_as_bytes(self)
}
fn encode(&self) -> Vec<u8> {
T::encode(self)
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
T::encode_secure(self, table_name, store_cipher)
}
}
impl EncodeKey for str {
fn encode(&self) -> Vec<u8> {
[self.as_bytes(), &[ENCODE_SEPARATOR]].concat()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_bytes().into()
}
}
impl EncodeKey for String {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_str().as_bytes().into()
}
}
impl EncodeKey for DeviceId {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_str().as_bytes().into()
}
}
impl EncodeKey for EventId {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_str().as_bytes().into()
}
}
impl EncodeKey for RoomId {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_str().as_bytes().into()
}
}
impl EncodeKey for TransactionId {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_str().as_bytes().into()
}
}
impl EncodeKey for MxcUri {
fn encode(&self) -> Vec<u8> {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
let s: &str = self.as_ref();
s.encode()
s.as_bytes().into()
}
}
impl EncodeKey for SecretName {
fn encode(&self) -> Vec<u8> {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
let s: &str = self.as_ref();
s.encode()
s.as_bytes().into()
}
}
impl EncodeKey for ReceiptType {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
let s: &str = self.as_ref();
s.as_bytes().into()
}
}
impl EncodeKey for EventEncryptionAlgorithm {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
let s: &str = self.as_ref();
s.as_bytes().into()
}
}
impl EncodeKey for RoomAccountDataEventType {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.to_string().as_bytes().to_vec().into()
}
}
impl EncodeKey for UserId {
fn encode(&self) -> Vec<u8> {
self.as_str().encode()
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.as_str().as_bytes().into()
}
}
impl EncodeKey for StateEventType {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.to_string().as_bytes().to_vec().into()
}
}
impl EncodeKey for GlobalAccountDataEventType {
fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
self.to_string().as_bytes().to_vec().into()
}
}
impl<A, B> EncodeKey for (A, B)
where
A: AsRef<str>,
B: AsRef<str>,
A: EncodeKey,
B: EncodeKey,
{
fn encode(&self) -> Vec<u8> {
[
self.0.as_ref().as_bytes(),
self.0.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
self.1.as_ref().as_bytes(),
self.1.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
]
.concat()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
[
store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
]
.concat()
@@ -94,17 +166,29 @@ where
impl<A, B, C> EncodeKey for (A, B, C)
where
A: AsRef<str>,
B: AsRef<str>,
C: AsRef<str>,
A: EncodeKey,
B: EncodeKey,
C: EncodeKey,
{
fn encode(&self) -> Vec<u8> {
[
self.0.as_ref().as_bytes(),
self.0.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
self.1.as_ref().as_bytes(),
self.1.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
self.2.as_ref().as_bytes(),
self.2.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
]
.concat()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
[
store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
store_cipher.hash_key(table_name, &self.2.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
]
.concat()
@@ -113,41 +197,36 @@ where
impl<A, B, C, D> EncodeKey for (A, B, C, D)
where
A: AsRef<str>,
B: AsRef<str>,
C: AsRef<str>,
D: AsRef<str>,
A: EncodeKey,
B: EncodeKey,
C: EncodeKey,
D: EncodeKey,
{
fn encode(&self) -> Vec<u8> {
[
self.0.as_ref().as_bytes(),
self.0.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
self.1.as_ref().as_bytes(),
self.1.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
self.2.as_ref().as_bytes(),
self.2.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
self.3.as_ref().as_bytes(),
self.3.encode_as_bytes().deref(),
&[ENCODE_SEPARATOR],
]
.concat()
}
fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
[
store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
store_cipher.hash_key(table_name, &self.2.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
store_cipher.hash_key(table_name, &self.3.encode_as_bytes()).as_slice(),
&[ENCODE_SEPARATOR],
]
.concat()
}
}
impl EncodeKey for StateEventType {
fn encode(&self) -> Vec<u8> {
self.to_string().encode()
}
}
impl EncodeKey for GlobalAccountDataEventType {
fn encode(&self) -> Vec<u8> {
self.to_string().encode()
}
}
#[cfg(feature = "state-store")]
pub fn encode_key_with_usize<A: AsRef<str>>(s: A, i: usize) -> Vec<u8> {
// FIXME: Not portable across architectures
[s.as_ref().as_bytes(), &[ENCODE_SEPARATOR], i.to_be_bytes().as_ref(), &[ENCODE_SEPARATOR]]
.concat()
}

View File

@@ -14,7 +14,6 @@
use std::{
collections::BTreeSet,
convert::TryInto,
path::{Path, PathBuf},
sync::Arc,
time::Instant,
@@ -23,15 +22,12 @@ use std::{
use anyhow::anyhow;
use async_stream::stream;
use futures_core::stream::Stream;
use futures_util::stream::{self, TryStreamExt};
use futures_util::stream::{self, StreamExt, TryStreamExt};
use matrix_sdk_base::{
deserialized_responses::{MemberEvent, SyncRoomEvent},
media::{MediaRequest, UniqueKey},
ruma::events::room::redaction::SyncRoomRedactionEvent,
store::{
store_key::{self, EncryptedEvent, StoreKey},
BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError,
},
store::{BoxStream, Result as StoreResult, StateChanges, StateStore, StoreError},
RoomInfo,
};
use matrix_sdk_common::{
@@ -52,6 +48,7 @@ use matrix_sdk_common::{
EventId, MxcUri, RoomId, RoomVersionId, UserId,
},
};
use matrix_sdk_store_encryption::{Error as KeyEncryptionError, StoreCipher};
use serde::{Deserialize, Serialize};
use sled::{
transaction::{ConflictableTransactionError, TransactionError},
@@ -62,22 +59,16 @@ use tracing::{info, warn};
#[cfg(feature = "crypto-store")]
use super::OpenStoreError;
use crate::encode_key::{encode_key_with_usize, EncodeKey, ENCODE_SEPARATOR};
use crate::encode_key::{EncodeKey, ENCODE_SEPARATOR};
#[cfg(feature = "crypto-store")]
pub use crate::CryptoStore;
#[derive(Debug, Serialize, Deserialize)]
pub enum DatabaseType {
Unencrypted,
Encrypted(store_key::EncryptedStoreKey),
}
#[derive(Debug, thiserror::Error)]
pub enum SledStoreError {
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Encryption(#[from] store_key::Error),
Encryption(#[from] KeyEncryptionError),
#[error(transparent)]
StoreError(#[from] StoreError),
#[error(transparent)]
@@ -104,9 +95,17 @@ impl Into<StoreError> for SledStoreError {
SledStoreError::Json(e) => StoreError::Json(e),
SledStoreError::Identifier(e) => StoreError::Identifier(e),
SledStoreError::Encryption(e) => match e {
store_key::Error::Random(e) => StoreError::Encryption(e.to_string()),
store_key::Error::Serialization(e) => StoreError::Json(e),
store_key::Error::Encryption(e) => StoreError::Encryption(e),
KeyEncryptionError::Random(e) => StoreError::Encryption(e.to_string()),
KeyEncryptionError::Serialization(e) => StoreError::Json(e),
KeyEncryptionError::Encryption(e) => StoreError::Encryption(e.to_string()),
KeyEncryptionError::Version(found, expected) => StoreError::Encryption(format!(
"Bad Database Encryption Version: expected {} found {}",
expected, found
)),
KeyEncryptionError::Length(found, expected) => StoreError::Encryption(format!(
"The database key an invalid length: expected {} found {}",
expected, found
)),
},
SledStoreError::StoreError(e) => e,
_ => StoreError::Backend(anyhow!(self)),
@@ -114,26 +113,36 @@ impl Into<StoreError> for SledStoreError {
}
}
const ACCOUNT_DATA: &str = "account-data";
const CUSTOM: &str = "custom";
const DISPLAY_NAME: &str = "display-name";
const INVITED_USER_ID: &str = "invited-user-id";
const JOINED_USER_ID: &str = "joined-user-id";
const MEDIA: &str = "media";
const MEMBER: &str = "member";
const PRESENCE: &str = "presence";
const PROFILE: &str = "profile";
const ROOM_ACCOUNT_DATA: &str = "room-account-data";
const ROOM_EVENT_ID_POSITION: &str = "room-event-id-to-position";
const ROOM_EVENT_RECEIPT: &str = "room-event-receipt";
const ROOM_INFO: &str = "room-info";
const ROOM_STATE: &str = "room-state";
const ROOM_USER_RECEIPT: &str = "room-user-receipt";
const ROOM: &str = "room";
const SESSION: &str = "session";
const STRIPPED_ROOM_INFO: &str = "stripped-room-info";
const STRIPPED_ROOM_MEMBER: &str = "stripped-room-member";
const STRIPPED_ROOM_STATE: &str = "stripped-room-state";
const TIMELINE_METADATA: &str = "timeline-metadata";
const TIMELINE: &str = "timeline";
type Result<A, E = SledStoreError> = std::result::Result<A, E>;
/// Get the value at `position` in encoded `key`.
///
/// The key must have been encoded with the `EncodeKey` trait. `position`
/// corresponds to the position in the tuple before the key was encoded. If it
/// wasn't encoded in a tuple, use `0`.
///
/// Returns `None` if there is no key at `position`.
pub fn decode_key_value(key: &[u8], position: usize) -> Option<String> {
let values: Vec<&[u8]> = key.split(|v| *v == ENCODE_SEPARATOR).collect();
values.get(position).map(|s| String::from_utf8_lossy(s).to_string())
}
#[derive(Clone)]
pub struct SledStore {
path: Option<PathBuf>,
pub(crate) inner: Db,
store_key: Arc<Option<StoreKey>>,
store_cipher: Arc<Option<StoreCipher>>,
session: Tree,
account_data: Tree,
members: Tree,
@@ -168,40 +177,44 @@ impl std::fmt::Debug for SledStore {
}
impl SledStore {
fn open_helper(db: Db, path: Option<PathBuf>, store_key: Option<StoreKey>) -> Result<Self> {
let session = db.open_tree("session")?;
let account_data = db.open_tree("account_data")?;
fn open_helper(
db: Db,
path: Option<PathBuf>,
store_cipher: Option<StoreCipher>,
) -> Result<Self> {
let session = db.open_tree(SESSION)?;
let account_data = db.open_tree(ACCOUNT_DATA)?;
let members = db.open_tree("members")?;
let profiles = db.open_tree("profiles")?;
let display_names = db.open_tree("display_names")?;
let joined_user_ids = db.open_tree("joined_user_ids")?;
let invited_user_ids = db.open_tree("invited_user_ids")?;
let members = db.open_tree(MEMBER)?;
let profiles = db.open_tree(PROFILE)?;
let display_names = db.open_tree(DISPLAY_NAME)?;
let joined_user_ids = db.open_tree(JOINED_USER_ID)?;
let invited_user_ids = db.open_tree(INVITED_USER_ID)?;
let room_state = db.open_tree("room_state")?;
let room_info = db.open_tree("room_infos")?;
let presence = db.open_tree("presence")?;
let room_account_data = db.open_tree("room_account_data")?;
let room_state = db.open_tree(ROOM_STATE)?;
let room_info = db.open_tree(ROOM_INFO)?;
let presence = db.open_tree(PRESENCE)?;
let room_account_data = db.open_tree(ROOM_ACCOUNT_DATA)?;
let stripped_room_infos = db.open_tree("stripped_room_infos")?;
let stripped_members = db.open_tree("stripped_members")?;
let stripped_room_state = db.open_tree("stripped_room_state")?;
let stripped_room_infos = db.open_tree(STRIPPED_ROOM_INFO)?;
let stripped_members = db.open_tree(STRIPPED_ROOM_MEMBER)?;
let stripped_room_state = db.open_tree(STRIPPED_ROOM_STATE)?;
let room_user_receipts = db.open_tree("room_user_receipts")?;
let room_event_receipts = db.open_tree("room_event_receipts")?;
let room_user_receipts = db.open_tree(ROOM_USER_RECEIPT)?;
let room_event_receipts = db.open_tree(ROOM_EVENT_RECEIPT)?;
let media = db.open_tree("media")?;
let media = db.open_tree(MEDIA)?;
let custom = db.open_tree("custom")?;
let custom = db.open_tree(CUSTOM)?;
let room_timeline = db.open_tree("room_timeline")?;
let room_timeline_metadata = db.open_tree("room_timeline_metadata")?;
let room_event_id_to_position = db.open_tree("room_event_id_to_position")?;
let room_timeline = db.open_tree(TIMELINE)?;
let room_timeline_metadata = db.open_tree(TIMELINE_METADATA)?;
let room_event_id_to_position = db.open_tree(ROOM_EVENT_ID_POSITION)?;
Ok(Self {
path,
inner: db,
store_key: store_key.into(),
store_cipher: store_cipher.into(),
session,
account_data,
members,
@@ -233,6 +246,20 @@ impl SledStore {
SledStore::open_helper(db, None, None).map_err(|e| e.into())
}
// testing only
#[cfg(test)]
fn open_encrypted() -> StoreResult<Self> {
let db =
Config::new().temporary(true).open().map_err(|e| StoreError::Backend(anyhow!(e)))?;
SledStore::open_helper(
db,
None,
Some(StoreCipher::new().expect("can't create store cipher")),
)
.map_err(|e| e.into())
}
pub fn open_with_passphrase(path: impl AsRef<Path>, passphrase: &str) -> StoreResult<Self> {
Self::inner_open_with_passphrase(path, passphrase).map_err(|e| e.into())
}
@@ -241,27 +268,15 @@ impl SledStore {
let path = path.as_ref().join("matrix-sdk-state");
let db = Config::new().temporary(false).path(&path).open()?;
let store_key: Option<DatabaseType> = db
.get("store_key".encode())?
.map(|k| serde_json::from_slice(&k).map_err(StoreError::Json))
.transpose()?;
let store_key = if let Some(key) = store_key {
if let DatabaseType::Encrypted(k) = key {
StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)?
} else {
return Err(SledStoreError::StoreError(StoreError::UnencryptedStore));
}
let store_cipher = if let Some(inner) = db.get("store_cipher".encode())? {
StoreCipher::import(passphrase, &inner)?
} else {
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
let encrypted_key = DatabaseType::Encrypted(
key.export(passphrase).map_err::<StoreError, _>(|e| e.into())?,
);
db.insert("store_key".encode(), serde_json::to_vec(&encrypted_key)?)?;
key
let cipher = StoreCipher::new()?;
db.insert("store_cipher".encode(), cipher.export(passphrase)?)?;
cipher
};
SledStore::open_helper(db, Some(path), Some(store_key))
SledStore::open_helper(db, Some(path), Some(store_cipher))
}
pub fn open_with_path(path: impl AsRef<Path>) -> StoreResult<Self> {
@@ -287,9 +302,8 @@ impl SledStore {
}
fn serialize_event(&self, event: &impl Serialize) -> Result<Vec<u8>, SledStoreError> {
if let Some(key) = &*self.store_key {
let encrypted = key.encrypt(event)?;
Ok(serde_json::to_vec(&encrypted)?)
if let Some(key) = &*self.store_cipher {
Ok(key.encrypt_value(event)?)
} else {
Ok(serde_json::to_vec(event)?)
}
@@ -299,24 +313,40 @@ impl SledStore {
&self,
event: &[u8],
) -> Result<T, SledStoreError> {
if let Some(key) = &*self.store_key {
let encrypted: EncryptedEvent = serde_json::from_slice(event)?;
Ok(key.decrypt(encrypted)?)
if let Some(key) = &*self.store_cipher {
Ok(key.decrypt_value(event)?)
} else {
Ok(serde_json::from_slice(event)?)
}
}
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session.insert(("filter", filter_name).encode(), filter_id)?;
fn encode_key<T: EncodeKey>(&self, table_name: &str, key: T) -> Vec<u8> {
if let Some(store_cipher) = &*self.store_cipher {
key.encode_secure(table_name, store_cipher).to_vec()
} else {
key.encode()
}
}
fn encode_key_with_counter<A: EncodeKey>(&self, tablename: &str, s: A, i: usize) -> Vec<u8> {
[
&self.encode_key(tablename, s),
[ENCODE_SEPARATOR].as_slice(),
i.to_be_bytes().as_ref(),
[ENCODE_SEPARATOR].as_slice(),
]
.concat()
}
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.session.insert(self.encode_key(SESSION, ("filter", filter_name)), filter_id)?;
Ok(())
}
pub async fn get_filter(&self, filter_name: &str) -> Result<Option<String>> {
Ok(self
.session
.get(("filter", filter_name).encode())?
.get(self.encode_key(SESSION, ("filter", filter_name)))?
.map(|f| String::from_utf8_lossy(&f).to_string()))
}
@@ -371,25 +401,31 @@ impl SledStore {
let profile_changes = changes.profiles.get(room);
for event in events.values() {
let key = (room, &event.state_key).encode();
let key = (room, &event.state_key);
match event.content.membership {
MembershipState::Join => {
joined.insert(key.as_slice(), event.state_key.as_str())?;
invited.remove(key.as_slice())?;
joined.insert(
self.encode_key(JOINED_USER_ID, &key),
event.state_key.as_str(),
)?;
invited.remove(self.encode_key(INVITED_USER_ID, &key))?;
}
MembershipState::Invite => {
invited.insert(key.as_slice(), event.state_key.as_str())?;
joined.remove(key.as_slice())?;
invited.insert(
self.encode_key(INVITED_USER_ID, &key),
event.state_key.as_str(),
)?;
joined.remove(self.encode_key(JOINED_USER_ID, &key))?;
}
_ => {
joined.remove(key.as_slice())?;
invited.remove(key.as_slice())?;
joined.remove(self.encode_key(JOINED_USER_ID, &key))?;
invited.remove(self.encode_key(INVITED_USER_ID, &key))?;
}
}
members.insert(
key.as_slice(),
self.encode_key(MEMBER, &key),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -398,7 +434,7 @@ impl SledStore {
profile_changes.and_then(|p| p.get(&event.state_key))
{
profiles.insert(
key.as_slice(),
self.encode_key(PROFILE, &key),
self.serialize_event(&profile)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -409,7 +445,7 @@ impl SledStore {
for (room_id, ambiguity_maps) in &changes.ambiguity_maps {
for (display_name, map) in ambiguity_maps {
display_names.insert(
(room_id, display_name).encode(),
self.encode_key(DISPLAY_NAME, (room_id, display_name)),
self.serialize_event(&map)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -418,7 +454,7 @@ impl SledStore {
for (event_type, event) in &changes.account_data {
account_data.insert(
event_type.encode(),
self.encode_key(ACCOUNT_DATA, event_type),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -427,7 +463,7 @@ impl SledStore {
for (room, events) in &changes.room_account_data {
for (event_type, event) in events {
room_account_data.insert(
(room, event_type.to_string()).encode(),
self.encode_key(ROOM_ACCOUNT_DATA, &(room, event_type)),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -438,7 +474,7 @@ impl SledStore {
for (event_type, events) in event_types {
for (state_key, event) in events {
state.insert(
(room, event_type.to_string(), state_key).encode(),
self.encode_key(ROOM_STATE, (room, event_type, state_key)),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -448,7 +484,7 @@ impl SledStore {
for (room_id, room_info) in &changes.room_infos {
rooms.insert(
room_id.encode(),
self.encode_key(ROOM, room_id),
self.serialize_event(room_info)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -456,7 +492,7 @@ impl SledStore {
for (sender, event) in &changes.presence {
presence.insert(
sender.encode(),
self.encode_key(PRESENCE, sender),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -464,7 +500,7 @@ impl SledStore {
for (room_id, info) in &changes.stripped_room_infos {
striped_rooms.insert(
room_id.encode(),
self.encode_key(STRIPPED_ROOM_INFO, room_id),
self.serialize_event(&info)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -473,7 +509,10 @@ impl SledStore {
for (room, events) in &changes.stripped_members {
for event in events.values() {
stripped_members.insert(
(room, event.state_key.to_string()).encode(),
self.encode_key(
STRIPPED_ROOM_MEMBER,
(room, event.state_key.to_string()),
),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -484,7 +523,10 @@ impl SledStore {
for (event_type, events) in event_types {
for (state_key, event) in events {
stripped_state.insert(
(room, event_type.to_string(), state_key).encode(),
self.encode_key(
STRIPPED_ROOM_STATE,
(room, event_type.to_string(), state_key),
),
self.serialize_event(&event)
.map_err(ConflictableTransactionError::Abort)?,
)?;
@@ -507,7 +549,10 @@ impl SledStore {
for (user_id, receipt) in receipts {
// Add the receipt to the room user receipts
if let Some(old) = room_user_receipts.insert(
(room, receipt_type, user_id).encode(),
self.encode_key(
ROOM_USER_RECEIPT,
(room, receipt_type, user_id),
),
self.serialize_event(&(event_id, receipt))
.map_err(ConflictableTransactionError::Abort)?,
)? {
@@ -515,15 +560,19 @@ impl SledStore {
let (old_event, _): (Box<EventId>, Receipt) = self
.deserialize_event(&old)
.map_err(ConflictableTransactionError::Abort)?;
room_event_receipts.remove(
(room, receipt_type, old_event, user_id).encode(),
)?;
room_event_receipts.remove(self.encode_key(
ROOM_EVENT_RECEIPT,
(room, receipt_type, old_event, user_id),
))?;
}
// Add the receipt to the room event receipts
room_event_receipts.insert(
(room, receipt_type, event_id, user_id).encode(),
self.serialize_event(receipt)
self.encode_key(
ROOM_EVENT_RECEIPT,
(room, receipt_type, event_id, user_id),
),
self.serialize_event(&(user_id, receipt))
.map_err(ConflictableTransactionError::Abort)?,
)?;
}
@@ -548,7 +597,7 @@ impl SledStore {
pub async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
let db = self.clone();
let key = user_id.encode();
let key = self.encode_key(PRESENCE, user_id);
spawn_blocking(move || db.presence.get(key)?.map(|e| db.deserialize_event(&e)).transpose())
.await?
}
@@ -560,7 +609,7 @@ impl SledStore {
state_key: &str,
) -> Result<Option<Raw<AnySyncStateEvent>>> {
let db = self.clone();
let key = (room_id, event_type.to_string(), state_key).encode();
let key = self.encode_key(ROOM_STATE, (room_id, event_type.to_string(), state_key));
spawn_blocking(move || {
db.room_state.get(key)?.map(|e| db.deserialize_event(&e)).transpose()
})
@@ -573,7 +622,7 @@ impl SledStore {
event_type: StateEventType,
) -> Result<Vec<Raw<AnySyncStateEvent>>> {
let db = self.clone();
let key = (room_id, event_type.to_string()).encode();
let key = self.encode_key(ROOM_STATE, (room_id, event_type.to_string()));
spawn_blocking(move || {
db.room_state
.scan_prefix(key)
@@ -589,7 +638,7 @@ impl SledStore {
user_id: &UserId,
) -> Result<Option<RoomMemberEventContent>> {
let db = self.clone();
let key = (room_id, user_id).encode();
let key = self.encode_key(PROFILE, (room_id, user_id));
spawn_blocking(move || db.profiles.get(key)?.map(|p| db.deserialize_event(&p)).transpose())
.await?
}
@@ -600,7 +649,7 @@ impl SledStore {
state_key: &UserId,
) -> Result<Option<MemberEvent>> {
let db = self.clone();
let key = (room_id, state_key).encode();
let key = self.encode_key(MEMBER, (room_id, state_key));
spawn_blocking(move || db.members.get(key)?.map(|v| db.deserialize_event(&v)).transpose())
.await?
}
@@ -609,29 +658,10 @@ impl SledStore {
&self,
room_id: &RoomId,
) -> StoreResult<impl Stream<Item = StoreResult<Box<UserId>>>> {
let decode = |key: &[u8]| -> StoreResult<Box<UserId>> {
let mut iter = key.split(|c| c == &ENCODE_SEPARATOR);
// Our key is a the room id separated from the user id by a null
// byte, discard the first value of the split.
iter.next();
let user_id = iter.next().expect("User ids weren't properly encoded");
Ok(UserId::parse(String::from_utf8_lossy(user_id).to_string())?)
};
let members = self.members.clone();
let key = room_id.encode();
spawn_blocking(move || {
stream::iter(
members
.scan_prefix(key)
.map(move |u| decode(&u.map_err(|e| StoreError::Backend(anyhow!(e)))?.0)),
)
})
.await
.map_err(|e| StoreError::Backend(anyhow!(e)))
Ok(self
.get_joined_user_ids(room_id)
.await?
.chain(self.get_invited_user_ids(room_id).await?))
}
pub async fn get_invited_user_ids(
@@ -639,7 +669,7 @@ impl SledStore {
room_id: &RoomId,
) -> StoreResult<impl Stream<Item = StoreResult<Box<UserId>>>> {
let db = self.clone();
let key = room_id.encode();
let key = self.encode_key(INVITED_USER_ID, room_id);
spawn_blocking(move || {
stream::iter(db.invited_user_ids.scan_prefix(key).map(|u| {
UserId::parse(
@@ -658,7 +688,7 @@ impl SledStore {
room_id: &RoomId,
) -> StoreResult<impl Stream<Item = StoreResult<Box<UserId>>>> {
let db = self.clone();
let key = room_id.encode();
let key = self.encode_key(JOINED_USER_ID, room_id);
spawn_blocking(move || {
stream::iter(db.joined_user_ids.scan_prefix(key).map(|u| {
UserId::parse(
@@ -696,7 +726,7 @@ impl SledStore {
display_name: &str,
) -> Result<BTreeSet<Box<UserId>>> {
let db = self.clone();
let key = (room_id, display_name).encode();
let key = self.encode_key(DISPLAY_NAME, (room_id, display_name));
spawn_blocking(move || {
Ok(db
.display_names
@@ -713,7 +743,7 @@ impl SledStore {
event_type: GlobalAccountDataEventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
let db = self.clone();
let key = event_type.encode();
let key = self.encode_key(ACCOUNT_DATA, event_type);
spawn_blocking(move || {
db.account_data.get(key)?.map(|m| db.deserialize_event(&m)).transpose()
})
@@ -726,7 +756,7 @@ impl SledStore {
event_type: RoomAccountDataEventType,
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
let db = self.clone();
let key = (room_id, event_type.to_string()).encode();
let key = self.encode_key(ROOM_ACCOUNT_DATA, (room_id, event_type));
spawn_blocking(move || {
db.room_account_data.get(key)?.map(|m| db.deserialize_event(&m)).transpose()
})
@@ -740,7 +770,7 @@ impl SledStore {
user_id: &UserId,
) -> Result<Option<(Box<EventId>, Receipt)>> {
let db = self.clone();
let key = (room_id, receipt_type, user_id).encode();
let key = self.encode_key(ROOM_USER_RECEIPT, (room_id, receipt_type, user_id));
spawn_blocking(move || {
db.room_user_receipts.get(key)?.map(|m| db.deserialize_event(&m)).transpose()
})
@@ -754,19 +784,14 @@ impl SledStore {
event_id: &EventId,
) -> StoreResult<Vec<(Box<UserId>, Receipt)>> {
let db = self.clone();
let key = (room_id, receipt_type, event_id).encode();
let key = self.encode_key(ROOM_EVENT_RECEIPT, (room_id, receipt_type, event_id));
spawn_blocking(move || {
db.room_event_receipts
.scan_prefix(key)
.values()
.map(|u| {
u.map_err(|e| StoreError::Backend(anyhow!(e))).and_then(|(key, value)| {
db.deserialize_event(&value)
// TODO remove this unwrapping
.map(|receipt| {
(decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt)
})
.map_err(|e| StoreError::Backend(anyhow!(e)))
})
let v = u.map_err(|e| StoreError::Backend(anyhow!(e)))?;
db.deserialize_event(&v).map_err(|e| StoreError::Backend(anyhow!(e)))
})
.collect()
})
@@ -775,8 +800,10 @@ impl SledStore {
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.media
.insert((request.source.unique_key(), request.format.unique_key()).encode(), data)?;
self.media.insert(
self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())),
data,
)?;
self.inner.flush_async().await?;
@@ -785,7 +812,8 @@ impl SledStore {
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
let db = self.clone();
let key = (request.source.unique_key(), request.format.unique_key()).encode();
let key =
self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key()));
spawn_blocking(move || Ok(db.media.get(key)?.map(|m| m.to_vec()))).await?
}
@@ -804,13 +832,15 @@ impl SledStore {
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.media.remove((request.source.unique_key(), request.format.unique_key()).encode())?;
self.media.remove(
self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())),
)?;
Ok(())
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let keys = self.media.scan_prefix(uri.encode()).keys();
let keys = self.media.scan_prefix(self.encode_key(MEDIA, uri)).keys();
let mut batch = sled::Batch::default();
for key in keys {
@@ -821,60 +851,75 @@ impl SledStore {
}
async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
let room_key = room_id.encode();
let mut members_batch = sled::Batch::default();
for key in self.members.scan_prefix(room_key.as_slice()).keys() {
for key in self.members.scan_prefix(self.encode_key(MEMBER, room_id)).keys() {
members_batch.remove(key?)
}
let mut profiles_batch = sled::Batch::default();
for key in self.profiles.scan_prefix(room_key.as_slice()).keys() {
for key in self.profiles.scan_prefix(self.encode_key(PROFILE, room_id)).keys() {
profiles_batch.remove(key?)
}
let mut display_names_batch = sled::Batch::default();
for key in self.display_names.scan_prefix(room_key.as_slice()).keys() {
for key in self.display_names.scan_prefix(self.encode_key(DISPLAY_NAME, room_id)).keys() {
display_names_batch.remove(key?)
}
let mut joined_user_ids_batch = sled::Batch::default();
for key in self.joined_user_ids.scan_prefix(room_key.as_slice()).keys() {
for key in self.joined_user_ids.scan_prefix(self.encode_key(JOINED_USER_ID, room_id)).keys()
{
joined_user_ids_batch.remove(key?)
}
let mut invited_user_ids_batch = sled::Batch::default();
for key in self.invited_user_ids.scan_prefix(room_key.as_slice()).keys() {
for key in
self.invited_user_ids.scan_prefix(self.encode_key(INVITED_USER_ID, room_id)).keys()
{
invited_user_ids_batch.remove(key?)
}
let mut room_state_batch = sled::Batch::default();
for key in self.room_state.scan_prefix(room_key.as_slice()).keys() {
for key in self.room_state.scan_prefix(self.encode_key(ROOM_STATE, room_id)).keys() {
room_state_batch.remove(key?)
}
let mut room_account_data_batch = sled::Batch::default();
for key in self.room_account_data.scan_prefix(room_key.as_slice()).keys() {
for key in
self.room_account_data.scan_prefix(self.encode_key(ROOM_ACCOUNT_DATA, room_id)).keys()
{
room_account_data_batch.remove(key?)
}
let mut stripped_members_batch = sled::Batch::default();
for key in self.stripped_members.scan_prefix(room_key.as_slice()).keys() {
for key in
self.stripped_members.scan_prefix(self.encode_key(STRIPPED_ROOM_MEMBER, room_id)).keys()
{
stripped_members_batch.remove(key?)
}
let mut stripped_room_state_batch = sled::Batch::default();
for key in self.stripped_room_state.scan_prefix(room_key.as_slice()).keys() {
for key in self
.stripped_room_state
.scan_prefix(self.encode_key(STRIPPED_ROOM_STATE, room_id))
.keys()
{
stripped_room_state_batch.remove(key?)
}
let mut room_user_receipts_batch = sled::Batch::default();
for key in self.room_user_receipts.scan_prefix(room_key.as_slice()).keys() {
for key in
self.room_user_receipts.scan_prefix(self.encode_key(ROOM_USER_RECEIPT, room_id)).keys()
{
room_user_receipts_batch.remove(key?)
}
let mut room_event_receipts_batch = sled::Batch::default();
for key in self.room_event_receipts.scan_prefix(room_key.as_slice()).keys() {
for key in self
.room_event_receipts
.scan_prefix(self.encode_key(ROOM_EVENT_RECEIPT, room_id))
.keys()
{
room_event_receipts_batch.remove(key?)
}
@@ -909,8 +954,8 @@ impl SledStore {
room_user_receipts,
room_event_receipts,
)| {
rooms.remove(room_key.as_slice())?;
stripped_rooms.remove(room_key.as_slice())?;
rooms.remove(self.encode_key(ROOM, room_id))?;
stripped_rooms.remove(self.encode_key(STRIPPED_ROOM_INFO, room_id))?;
members.apply_batch(&members_batch)?;
profiles.apply_batch(&profiles_batch)?;
@@ -942,7 +987,7 @@ impl SledStore {
room_id: &RoomId,
) -> Result<Option<(BoxStream<StoreResult<SyncRoomEvent>>, Option<String>)>> {
let db = self.clone();
let key = room_id.encode();
let key = self.encode_key(TIMELINE_METADATA, room_id);
let r_id = room_id.to_owned();
let metadata: Option<TimelineMetadata> = db
.room_timeline_metadata
@@ -963,7 +1008,7 @@ impl SledStore {
info!("Found previously stored timeline for {}, with end token {:?}", r_id, end_token);
let stream = stream! {
while let Ok(Some(item)) = db.room_timeline.get(&encode_key_with_usize(&r_id, position)) {
while let Ok(Some(item)) = db.room_timeline.get(&db.encode_key_with_counter(TIMELINE, &r_id, position)) {
position += 1;
yield db.deserialize_event(&item).map_err(SledStoreError::from).map_err(|e| e.into());
}
@@ -973,16 +1018,19 @@ impl SledStore {
}
async fn remove_room_timeline(&self, room_id: &RoomId) -> Result<()> {
let room_key = room_id.encode();
info!("Remove stored timeline for {}", room_id);
let mut timeline_batch = sled::Batch::default();
for key in self.room_timeline.scan_prefix(room_key.as_slice()).keys() {
for key in self.room_timeline.scan_prefix(self.encode_key(TIMELINE, &room_id)).keys() {
timeline_batch.remove(key?)
}
let mut event_id_to_position_batch = sled::Batch::default();
for key in self.room_event_id_to_position.scan_prefix(room_key.as_slice()).keys() {
for key in self
.room_event_id_to_position
.scan_prefix(self.encode_key(ROOM_EVENT_ID_POSITION, &room_id))
.keys()
{
event_id_to_position_batch.remove(key?)
}
@@ -990,7 +1038,8 @@ impl SledStore {
(&self.room_timeline, &self.room_timeline_metadata, &self.room_event_id_to_position)
.transaction(
|(room_timeline, room_timeline_metadata, room_event_id_to_position)| {
room_timeline_metadata.remove(room_key.as_slice())?;
room_timeline_metadata
.remove(self.encode_key(TIMELINE_METADATA, &room_id))?;
room_timeline.apply_batch(&timeline_batch)?;
room_event_id_to_position.apply_batch(&event_id_to_position_batch)?;
@@ -1015,7 +1064,6 @@ impl SledStore {
} else {
info!("Save new timeline batch from messages response for {}", room_id);
}
let room_key = room_id.encode();
let metadata: Option<TimelineMetadata> = if timeline.limited {
info!(
@@ -1027,7 +1075,7 @@ impl SledStore {
} else {
let metadata: Option<TimelineMetadata> = self
.room_timeline_metadata
.get(room_key.as_slice())?
.get(self.encode_key(TIMELINE_METADATA, &room_id))?
.map(|v| serde_json::from_slice(&v).map_err(StoreError::Json))
.transpose()?;
if let Some(mut metadata) = metadata {
@@ -1043,7 +1091,8 @@ impl SledStore {
let mut delete_timeline = false;
for event in &timeline.events {
if let Some(event_id) = event.event_id() {
let event_key = (room_id, event_id).encode();
let event_key =
self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, event_id));
if self.room_event_id_to_position.contains_key(event_key)? {
delete_timeline = true;
break;
@@ -1082,7 +1131,7 @@ impl SledStore {
};
let room_version = self
.room_info
.get(room_id.encode())?
.get(&self.encode_key(ROOM_INFO, room_id))?
.map(|r| self.deserialize_event::<RoomInfo>(&r))
.transpose()?
.and_then(|info| info.room_version().cloned())
@@ -1100,7 +1149,8 @@ impl SledStore {
)),
)) = event.event.deserialize()
{
let redacts_key = (room_id, redaction.redacts).encode();
let redacts_key =
self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, redaction.redacts));
if let Some(position_key) =
self.room_event_id_to_position.get(redacts_key)?
{
@@ -1125,28 +1175,35 @@ impl SledStore {
}
metadata.start_position -= 1;
let key = encode_key_with_usize(room_id, metadata.start_position);
let key =
self.encode_key_with_counter(TIMELINE, room_id, metadata.start_position);
timeline_batch.insert(key.as_slice(), self.serialize_event(&event)?);
// Only add event with id to the position map
if let Some(event_id) = event.event_id() {
let event_key = (room_id, event_id).encode();
let event_key =
self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, event_id));
event_id_to_position_batch.insert(event_key.as_slice(), key.as_slice());
}
}
} else {
for event in &timeline.events {
metadata.end_position += 1;
let key = encode_key_with_usize(room_id, metadata.end_position);
let key =
self.encode_key_with_counter(TIMELINE, room_id, metadata.end_position);
timeline_batch.insert(key.as_slice(), self.serialize_event(&event)?);
// Only add event with id to the position map
if let Some(event_id) = event.event_id() {
let event_key = (room_id, event_id).encode();
let event_key =
self.encode_key(ROOM_EVENT_ID_POSITION, (room_id, event_id));
event_id_to_position_batch.insert(event_key.as_slice(), key.as_slice());
}
}
}
timeline_metadata_batch.insert(room_key, serde_json::to_vec(&metadata)?);
timeline_metadata_batch.insert(
self.encode_key(TIMELINE_METADATA, &room_id),
serde_json::to_vec(&metadata)?,
);
}
let ret: Result<(), TransactionError<SledStoreError>> =
@@ -1355,3 +1412,16 @@ mod tests {
statestore_integration_tests! { integration }
}
#[cfg(test)]
mod encrypted_tests {
use matrix_sdk_base::statestore_integration_tests;
use super::{SledStore, StateStore, StoreResult};
async fn get_store() -> StoreResult<impl StateStore> {
SledStore::open_encrypted().map_err(Into::into)
}
statestore_integration_tests! { integration }
}

View File

@@ -284,16 +284,83 @@ impl StoreCipher {
/// # Result::<_, anyhow::Error>::Ok(()) };
/// ```
pub fn encrypt_value(&self, value: &impl Serialize) -> Result<Vec<u8>, Error> {
let mut data = serde_json::to_vec(value)?;
Ok(serde_json::to_vec(&self.encrypt_value_typed(value)?)?)
}
/// Encrypt a value before it is inserted into the key/value store.
///
/// A value can be decrypted using the
/// [`StoreCipher::decrypt_value_typed()`] method. This is the lower
/// level function to `encrypt_value`, but returns the
/// full `EncryptdValue`-type
///
/// # Arguments
///
/// * `value` - A value that should be encrypted, any value that implements
/// `Serialize` can be given to this method. The value will be serialized as
/// json before it is encrypted.
///
///
/// # Examples
///
/// ```no_run
/// # let example = || {
/// use matrix_sdk_store_encryption::StoreCipher;
/// use serde_json::{json, value::Value};
///
/// let store_cipher = StoreCipher::new()?;
///
/// let value = json!({
/// "some": "data",
/// });
///
/// let encrypted = store_cipher.encrypt_value_typed(&value)?;
/// let decrypted: Value = store_cipher.decrypt_value_typed(encrypted)?;
///
/// assert_eq!(value, decrypted);
/// # anyhow::Ok(()) };
/// ```
pub fn encrypt_value_typed(&self, value: &impl Serialize) -> Result<EncryptedValue, Error> {
let data = serde_json::to_vec(value)?;
self.encrypt_value_data(data)
}
/// Encrypt some data before it is inserted into the key/value store.
///
/// A value can be decrypted using the [`StoreCipher::decrypt_value_data()`]
/// method. This is the lower level function to `encrypt_value`
///
/// # Arguments
///
/// * `data` - A value that should be encrypted, encoded as a `Vec<u8>`
///
/// # Examples
///
/// ```
/// # let example = || {
/// use matrix_sdk_store_encryption::StoreCipher;
/// use serde_json::{json, value::Value};
///
/// let store_cipher = StoreCipher::new()?;
///
/// let value = serde_json::to_vec(&json!({
/// "some": "data",
/// }))?;
///
/// let encrypted = store_cipher.encrypt_value_data(value.clone())?;
/// let decrypted = store_cipher.decrypt_value_data(encrypted)?;
///
/// assert_eq!(value, decrypted);
/// # Result::<_, anyhow::Error>::Ok(()) };
/// ```
pub fn encrypt_value_data(&self, mut data: Vec<u8>) -> Result<EncryptedValue, Error> {
let nonce = Keys::get_nonce()?;
let cipher = XChaCha20Poly1305::new(self.inner.encryption_key());
let ciphertext = cipher.encrypt(XNonce::from_slice(&nonce), data.as_ref())?;
data.zeroize();
Ok(serde_json::to_vec(&EncryptedValue { version: VERSION, ciphertext, nonce })?)
Ok(EncryptedValue { version: VERSION, ciphertext, nonce })
}
/// Decrypt a value after it was fetchetd from the key/value store.
@@ -325,22 +392,91 @@ impl StoreCipher {
///
/// assert_eq!(value, decrypted);
/// # Result::<_, anyhow::Error>::Ok(()) };
/// ```
pub fn decrypt_value<T: for<'b> Deserialize<'b>>(&self, value: &[u8]) -> Result<T, Error> {
let value: EncryptedValue = serde_json::from_slice(value)?;
self.decrypt_value_typed(value)
}
/// Decrypt a value after it was fetchetd from the key/value store.
///
/// A value can be encrypted using the
/// [`StoreCipher::encrypt_value_typed()`] method. Lower level method to
/// [`StoreCipher::decrypt_value_typed()`]
///
/// # Arguments
///
/// * `value` - The EncryptedValue of a value that should be decrypted.
///
/// The method will deserialize the decrypted value into the expected type.
///
/// # Examples
///
/// ```
/// # let example = || {
/// use matrix_sdk_store_encryption::StoreCipher;
/// use serde_json::{json, value::Value};
///
/// let store_cipher = StoreCipher::new()?;
///
/// let value = json!({
/// "some": "data",
/// });
///
/// let encrypted = store_cipher.encrypt_value_typed(&value)?;
/// let decrypted: Value = store_cipher.decrypt_value_typed(encrypted)?;
///
/// assert_eq!(value, decrypted);
/// # Result::<_, anyhow::Error>::Ok(()) };
/// ```
pub fn decrypt_value_typed<T: for<'b> Deserialize<'b>>(
&self,
value: EncryptedValue,
) -> Result<T, Error> {
let mut plaintext = self.decrypt_value_data(value)?;
let ret = serde_json::from_slice(&plaintext);
plaintext.zeroize();
Ok(ret?)
}
/// Decrypt a value after it was fetchetd from the key/value store.
///
/// A value can be encrypted using the [`StoreCipher::encrypt_value_data()`]
/// method. Lower level method to [`StoreCipher::decrypt_value()`].
///
/// # Arguments
///
/// * `value` - The EncryptedValue of a value that should be decrypted.
///
/// The method will return the raw decrypted value
///
/// # Examples
///
/// ```
/// # let example = || {
/// use matrix_sdk_store_encryption::StoreCipher;
/// use serde_json::{json, value::Value};
///
/// let store_cipher = StoreCipher::new()?;
///
/// let value = serde_json::to_vec(&json!({
/// "some": "data",
/// }))?;
///
/// let encrypted = store_cipher.encrypt_value_data(value.clone())?;
/// let decrypted = store_cipher.decrypt_value_data(encrypted)?;
///
/// assert_eq!(value, decrypted);
/// # Result::<_, anyhow::Error>::Ok(()) };
/// ```
pub fn decrypt_value_data(&self, value: EncryptedValue) -> Result<Vec<u8>, Error> {
if value.version != VERSION {
return Err(Error::Version(VERSION, value.version));
}
let cipher = XChaCha20Poly1305::new(self.inner.encryption_key());
let nonce = XNonce::from_slice(&value.nonce);
let mut plaintext = cipher.decrypt(nonce, value.ciphertext.as_ref())?;
let ret = serde_json::from_slice(&plaintext);
plaintext.zeroize();
Ok(ret?)
Ok(cipher.decrypt(nonce, value.ciphertext.as_ref())?)
}
/// Expand the given passphrase into a KEY_SIZE long key.
@@ -362,8 +498,10 @@ impl MacKey {
}
}
/// Encrypted value, ready for storage, as created by the
/// [`StoreCipher::encrypt_value_data()`]
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct EncryptedValue {
pub struct EncryptedValue {
version: u8,
ciphertext: Vec<u8>,
nonce: [u8; XNONCE_SIZE],