moving sled state store into separate crate

This commit is contained in:
Benjamin Kampmann
2022-02-17 19:02:17 +01:00
parent ea92bc676c
commit 0ca6ec1377
6 changed files with 146 additions and 117 deletions

View File

@@ -19,11 +19,7 @@ rustdoc-args = ["--cfg", "docsrs"]
default = []
encryption = ["matrix-sdk-crypto"]
qrcode = ["matrix-sdk-crypto/qrcode"]
sled_state_store = [
"sled",
"tokio",
"state_key",
]
sled_state_store = [ ]
sled_cryptostore = ["matrix-sdk-crypto/sled_cryptostore"]
indexeddb_state_store = ["state_key"]
@@ -54,15 +50,11 @@ rand = { version = "0.8.4", optional = true }
serde = { version = "1.0.126", features = ["rc"] }
serde_json = "1.0.64"
sha2 = { version = "0.10.1", optional = true }
sled = { version = "0.34.6", optional = true }
thiserror = "1.0.25"
tracing = "0.1.26"
zeroize = { version = "1.3.0", features = ["zeroize_derive"] }
anyhow = "1"
## Feature sled_state_store
tokio = { version = "1.7.1", optional = true, default-features = false, features = ["sync", "fs"] }
[dependencies.ruma]
git = "https://github.com/ruma/ruma/"
rev = "b9f32bc6327542d382d4eb42ec43623495c50e66"

View File

@@ -36,7 +36,7 @@ macro_rules! statestore_integration_tests {
store::{
Store,
StateStore,
Result,
Result as StoreResult,
StateChanges
}
};
@@ -64,7 +64,7 @@ macro_rules! statestore_integration_tests {
}
/// Populate the given `StateStore`.
pub(crate) async fn populated_store(inner: Box<dyn StateStore>) -> Result<Store> {
pub(crate) async fn populated_store(inner: Box<dyn StateStore>) -> StoreResult<Store> {
let mut changes = StateChanges::default();
let store = Store::new(inner);
@@ -231,7 +231,7 @@ macro_rules! statestore_integration_tests {
}
#[async_test]
async fn test_populate_store() -> Result<()> {
async fn test_populate_store() -> StoreResult<()> {
let room_id = room_id();
let user_id = user_id();
let inner_store = get_store().await?;
@@ -449,7 +449,7 @@ macro_rules! statestore_integration_tests {
}
#[async_test]
async fn test_custom_storage() -> Result<()> {
async fn test_custom_storage() -> StoreResult<()> {
let key = "my_key";
let value = &[0, 1, 2, 3];
let store = get_store().await?;
@@ -464,7 +464,7 @@ macro_rules! statestore_integration_tests {
}
#[async_test]
async fn test_persist_invited_room() -> Result<()> {
async fn test_persist_invited_room() -> StoreResult<()> {
let stripped_room_id = stripped_room_id();
let inner_store = get_store().await?;
let store = populated_store(Box::new(inner_store)).await?;
@@ -477,7 +477,7 @@ macro_rules! statestore_integration_tests {
}
#[async_test]
async fn test_room_removal() -> Result<()> {
async fn test_room_removal() -> StoreResult<()> {
let room_id = room_id();
let user_id = user_id();
let inner_store = get_store().await?;

View File

@@ -57,13 +57,9 @@ use crate::{
pub(crate) mod ambiguity_map;
mod memory_store;
#[cfg(feature = "sled_state_store")]
mod sled_store;
#[cfg(not(any(feature = "sled_state_store", feature = "indexeddb_state_store")))]
use self::memory_store::MemoryStore;
#[cfg(feature = "sled_state_store")]
use self::sled_store::SledStore;
/// State store specific error type.
#[derive(Debug, thiserror::Error)]
@@ -91,10 +87,6 @@ pub enum StoreError {
/// The store failed to encode or decode some data.
#[error("Error encoding or decoding data from the store: {0}")]
Codec(String),
/// An error happened while running a tokio task.
#[cfg(feature = "sled_state_store")]
#[error(transparent)]
Task(#[from] tokio::task::JoinError),
}
/// A `StateStore` specific result type.
pub type Result<T, E = StoreError> = std::result::Result<T, E>;

View File

@@ -0,0 +1,28 @@
[package]
name = "matrix-sdk-sled"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
futures-core = "0.3.15"
futures-util = { version = "0.3.15", default-features = false }
matrix-sdk-base = { path = "../matrix-sdk-base", features = ["state_key"] }
matrix-sdk-common = { path = "../matrix-sdk-common" }
matrix-sdk-crypto = { path = "../matrix-sdk-crypto" }
serde = "1"
serde_json = "1.0.64"
sled = { version = "0.34.6" }
thiserror = "1.0.25"
tokio = { version = "1.7.1", default-features = false, features = ["sync", "fs"] }
tracing = "0.1.26"
anyhow = "1"
[dev-dependencies]
matrix-sdk-test = { version = "0.4.0", path = "../matrix-sdk-test" }
matrix-sdk-base = { path = "../matrix-sdk-base", features = ["testing"] }
tokio = { version = "1.7.1", default-features = false, features = [
"rt-multi-thread",
"macros",
] }

View File

@@ -0,0 +1 @@
pub mod state_store;

View File

@@ -23,7 +23,8 @@ use std::{
use futures_core::stream::Stream;
use futures_util::stream::{self, TryStreamExt};
use matrix_sdk_common::async_trait;
use ruma::{
use matrix_sdk_common::ruma::{
self,
events::{
presence::PresenceEvent,
receipt::Receipt,
@@ -41,10 +42,13 @@ use sled::{
};
use tokio::task::spawn_blocking;
use tracing::info;
use anyhow::anyhow;
use self::store_key::{EncryptedEvent, StoreKey};
use super::{store_key, Result, RoomInfo, StateChanges, StateStore, StoreError};
use crate::{
use matrix_sdk_base::{
store::{
store_key::{self, EncryptedEvent, StoreKey},
Result as StoreResult, StateChanges, StateStore, StoreError},
RoomInfo,
deserialized_responses::MemberEvent,
media::{MediaRequest, UniqueKey},
};
@@ -56,35 +60,48 @@ pub enum DatabaseType {
}
#[derive(Debug, thiserror::Error)]
pub enum SerializationError {
pub enum SledStoreError {
#[error(transparent)]
Json(#[from] serde_json::Error),
#[error(transparent)]
Encryption(#[from] store_key::Error),
#[error(transparent)]
StoreError(#[from] StoreError),
#[error(transparent)]
TransactionError(#[from] sled::Error),
#[error(transparent)]
Identifier(#[from] ruma::identifiers::Error),
#[error(transparent)]
Task(#[from] tokio::task::JoinError),
}
impl From<TransactionError<SerializationError>> for StoreError {
fn from(e: TransactionError<SerializationError>) -> Self {
impl From<TransactionError<SledStoreError>> for SledStoreError {
fn from(e: TransactionError<SledStoreError>) -> Self {
match e {
TransactionError::Abort(e) => e.into(),
TransactionError::Storage(e) => StoreError::Sled(e),
TransactionError::Storage(e) => SledStoreError::TransactionError(e),
}
}
}
impl From<SerializationError> for StoreError {
fn from(e: SerializationError) -> Self {
match e {
SerializationError::Json(e) => StoreError::Json(e),
SerializationError::Encryption(e) => match e {
impl Into<StoreError> for SledStoreError {
fn into(self) -> StoreError {
match self {
SledStoreError::Json(e) => StoreError::Json(e),
SledStoreError::Identifier(e) => StoreError::Identifier(e),
SledStoreError::Encryption(e) => match e {
store_key::Error::Random(e) => StoreError::Encryption(e.to_string()),
store_key::Error::Serialization(e) => StoreError::Json(e),
store_key::Error::Encryption(e) => StoreError::Encryption(e),
},
SledStoreError::StoreError(e) => e,
_ => StoreError::Backend(anyhow!(self))
}
}
}
type Result<A, E = SledStoreError> = std::result::Result<A, E>;
const ENCODE_SEPARATOR: u8 = 0xff;
trait EncodeKey {
@@ -288,7 +305,7 @@ impl SledStore {
if let DatabaseType::Encrypted(k) = key {
StoreKey::import(passphrase, k).map_err(|_| StoreError::StoreLocked)?
} else {
return Err(StoreError::UnencryptedStore);
return Err(SledStoreError::StoreError(StoreError::UnencryptedStore));
}
} else {
let key = StoreKey::new().map_err::<StoreError, _>(|e| e.into())?;
@@ -309,7 +326,7 @@ impl SledStore {
SledStore::open_helper(db, Some(path), None)
}
fn serialize_event(&self, event: &impl Serialize) -> Result<Vec<u8>, SerializationError> {
fn serialize_event(&self, event: &impl Serialize) -> Result<Vec<u8>, SledStoreError> {
if let Some(key) = &*self.store_key {
let encrypted = key.encrypt(event)?;
Ok(serde_json::to_vec(&encrypted)?)
@@ -321,7 +338,7 @@ impl SledStore {
fn deserialize_event<T: for<'b> Deserialize<'b>>(
&self,
event: &[u8],
) -> Result<T, SerializationError> {
) -> Result<T, SledStoreError> {
if let Some(key) = &*self.store_key {
let encrypted: EncryptedEvent = serde_json::from_slice(event)?;
Ok(key.decrypt(encrypted)?)
@@ -353,7 +370,7 @@ impl SledStore {
pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
let now = Instant::now();
let ret: Result<(), TransactionError<SerializationError>> = (
let ret: Result<(), TransactionError<SledStoreError>> = (
&self.session,
&self.account_data,
&self.members,
@@ -523,7 +540,7 @@ impl SledStore {
ret?;
let ret: Result<(), TransactionError<SerializationError>> =
let ret: Result<(), TransactionError<SledStoreError>> =
(&self.room_user_receipts, &self.room_event_receipts).transaction(
|(room_user_receipts, room_event_receipts)| {
for (room, content) in &changes.receipts {
@@ -651,8 +668,8 @@ impl SledStore {
pub async fn get_user_ids_stream(
&self,
room_id: &RoomId,
) -> Result<impl Stream<Item = Result<Box<UserId>>>> {
let decode = |key: &[u8]| -> Result<Box<UserId>> {
) -> StoreResult<impl Stream<Item = StoreResult<Box<UserId>>>> {
let decode = |key: &[u8]| -> StoreResult<Box<UserId>> {
let mut iter = key.split(|c| c == &ENCODE_SEPARATOR);
// Our key is a the room id separated from the user id by a null
// byte, discard the first value of the split.
@@ -666,41 +683,38 @@ impl SledStore {
let members = self.members.clone();
let key = room_id.encode();
spawn_blocking(move || stream::iter(members.scan_prefix(key).map(move |u| decode(&u?.0))))
.await
.map_err(Into::into)
spawn_blocking(move || stream::iter(members.scan_prefix(key).map(move |u| decode(&u.map_err(|e| StoreError::Backend(anyhow!(e)))?.0))))
.await.map_err(|e| StoreError::Backend(anyhow!(e)))
}
pub async fn get_invited_user_ids(
&self,
room_id: &RoomId,
) -> Result<impl Stream<Item = Result<Box<UserId>>>> {
) -> StoreResult<impl Stream<Item = StoreResult<Box<UserId>>>> {
let db = self.clone();
let key = room_id.encode();
spawn_blocking(move || {
stream::iter(db.invited_user_ids.scan_prefix(key).map(|u| {
UserId::parse(String::from_utf8_lossy(&u?.1).to_string())
UserId::parse(String::from_utf8_lossy(&u.map_err(|e| StoreError::Backend(anyhow!(e)))?.1).to_string())
.map_err(StoreError::Identifier)
}))
})
.await
.map_err(Into::into)
.await.map_err(|e| StoreError::Backend(anyhow!(e)))
}
pub async fn get_joined_user_ids(
&self,
room_id: &RoomId,
) -> Result<impl Stream<Item = Result<Box<UserId>>>> {
) -> StoreResult<impl Stream<Item = StoreResult<Box<UserId>>>> {
let db = self.clone();
let key = room_id.encode();
spawn_blocking(move || {
stream::iter(db.joined_user_ids.scan_prefix(key).map(|u| {
UserId::parse(String::from_utf8_lossy(&u?.1).to_string())
UserId::parse(String::from_utf8_lossy(&u.map_err(|e| StoreError::Backend(anyhow!(e)))?.1).to_string())
.map_err(StoreError::Identifier)
}))
})
.await
.map_err(Into::into)
.await.map_err(|e| StoreError::Backend(anyhow!(e)))
}
pub async fn get_room_infos(&self) -> Result<impl Stream<Item = Result<RoomInfo>>> {
@@ -789,25 +803,26 @@ impl SledStore {
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(Box<UserId>, Receipt)>> {
) -> StoreResult<Vec<(Box<UserId>, Receipt)>> {
let db = self.clone();
let key = (room_id.as_str(), receipt_type.as_ref(), event_id.as_str()).encode();
spawn_blocking(move || {
db.room_event_receipts
.scan_prefix(key)
.map(|u| {
u.map_err(StoreError::Sled).and_then(|(key, value)| {
db.deserialize_event(&value)
// TODO remove this unwrapping
.map(|receipt| {
(decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt)
})
.map_err(Into::into)
})
u
.map_err(|e| StoreError::Backend(anyhow!(e)))
.and_then(|(key, value)| {
db.deserialize_event(&value)
// TODO remove this unwrapping
.map(|receipt| {
(decode_key_value(&key, 3).unwrap().try_into().unwrap(), receipt)
}).map_err(|e| StoreError::Backend(anyhow!(e)))
})
})
.collect()
})
.await?
.await.map_err(|e| StoreError::Backend(anyhow!(e)))?
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
@@ -921,7 +936,7 @@ impl SledStore {
room_event_receipts_batch.remove(key?)
}
let ret: Result<(), TransactionError<SerializationError>> = (
let ret: Result<(), TransactionError<SledStoreError>> = (
&self.members,
&self.profiles,
&self.display_names,
@@ -981,24 +996,24 @@ impl SledStore {
#[async_trait]
impl StateStore for SledStore {
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.save_filter(filter_name, filter_id).await
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> {
self.save_filter(filter_name, filter_id).await.map_err(Into::into)
}
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
self.save_changes(changes).await
async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> {
self.save_changes(changes).await.map_err(Into::into)
}
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>> {
self.get_filter(filter_id).await
async fn get_filter(&self, filter_id: &str) -> StoreResult<Option<String>> {
self.get_filter(filter_id).await.map_err(Into::into)
}
async fn get_sync_token(&self) -> Result<Option<String>> {
self.get_sync_token().await
async fn get_sync_token(&self) -> StoreResult<Option<String>> {
self.get_sync_token().await.map_err(Into::into)
}
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
self.get_presence_event(user_id).await
async fn get_presence_event(&self, user_id: &UserId) -> StoreResult<Option<Raw<PresenceEvent>>> {
self.get_presence_event(user_id).await.map_err(Into::into)
}
async fn get_state_event(
@@ -1006,75 +1021,75 @@ impl StateStore for SledStore {
room_id: &RoomId,
event_type: EventType,
state_key: &str,
) -> Result<Option<Raw<AnySyncStateEvent>>> {
self.get_state_event(room_id, event_type, state_key).await
) -> StoreResult<Option<Raw<AnySyncStateEvent>>> {
self.get_state_event(room_id, event_type, state_key).await.map_err(Into::into)
}
async fn get_state_events(
&self,
room_id: &RoomId,
event_type: EventType,
) -> Result<Vec<Raw<AnySyncStateEvent>>> {
self.get_state_events(room_id, event_type).await
) -> StoreResult<Vec<Raw<AnySyncStateEvent>>> {
self.get_state_events(room_id, event_type).await.map_err(Into::into)
}
async fn get_profile(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<RoomMemberEventContent>> {
self.get_profile(room_id, user_id).await
) -> StoreResult<Option<RoomMemberEventContent>> {
self.get_profile(room_id, user_id).await.map_err(Into::into)
}
async fn get_member_event(
&self,
room_id: &RoomId,
state_key: &UserId,
) -> Result<Option<MemberEvent>> {
self.get_member_event(room_id, state_key).await
) -> StoreResult<Option<MemberEvent>> {
self.get_member_event(room_id, state_key).await.map_err(Into::into)
}
async fn get_user_ids(&self, room_id: &RoomId) -> Result<Vec<Box<UserId>>> {
async fn get_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<Box<UserId>>> {
self.get_user_ids_stream(room_id).await?.try_collect().await
}
async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result<Vec<Box<UserId>>> {
async fn get_invited_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<Box<UserId>>> {
self.get_invited_user_ids(room_id).await?.try_collect().await
}
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<Box<UserId>>> {
async fn get_joined_user_ids(&self, room_id: &RoomId) -> StoreResult<Vec<Box<UserId>>> {
self.get_joined_user_ids(room_id).await?.try_collect().await
}
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>> {
self.get_room_infos().await?.try_collect().await
async fn get_room_infos(&self) -> StoreResult<Vec<RoomInfo>> {
self.get_room_infos().await.map_err::<StoreError, _>(Into::into)?.try_collect().await.map_err::<StoreError, _>(Into::into)
}
async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>> {
self.get_stripped_room_infos().await?.try_collect().await
async fn get_stripped_room_infos(&self) -> StoreResult<Vec<RoomInfo>> {
self.get_stripped_room_infos().await.map_err::<StoreError, _>(Into::into)?.try_collect().await.map_err::<StoreError, _>(Into::into)
}
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
) -> Result<BTreeSet<Box<UserId>>> {
self.get_users_with_display_name(room_id, display_name).await
) -> StoreResult<BTreeSet<Box<UserId>>> {
self.get_users_with_display_name(room_id, display_name).await.map_err(Into::into)
}
async fn get_account_data_event(
&self,
event_type: EventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
self.get_account_data_event(event_type).await
) -> StoreResult<Option<Raw<AnyGlobalAccountDataEvent>>> {
self.get_account_data_event(event_type).await.map_err(Into::into)
}
async fn get_room_account_data_event(
&self,
room_id: &RoomId,
event_type: EventType,
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
self.get_room_account_data_event(room_id, event_type).await
) -> StoreResult<Option<Raw<AnyRoomAccountDataEvent>>> {
self.get_room_account_data_event(room_id, event_type).await.map_err(Into::into)
}
async fn get_user_room_receipt_event(
@@ -1082,8 +1097,8 @@ impl StateStore for SledStore {
room_id: &RoomId,
receipt_type: ReceiptType,
user_id: &UserId,
) -> Result<Option<(Box<EventId>, Receipt)>> {
self.get_user_room_receipt_event(room_id, receipt_type, user_id).await
) -> StoreResult<Option<(Box<EventId>, Receipt)>> {
self.get_user_room_receipt_event(room_id, receipt_type, user_id).await.map_err(Into::into)
}
async fn get_event_room_receipt_events(
@@ -1091,46 +1106,47 @@ impl StateStore for SledStore {
room_id: &RoomId,
receipt_type: ReceiptType,
event_id: &EventId,
) -> Result<Vec<(Box<UserId>, Receipt)>> {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await
) -> StoreResult<Vec<(Box<UserId>, Receipt)>> {
self.get_event_room_receipt_events(room_id, receipt_type, event_id).await.map_err(Into::into)
}
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
self.get_custom_value(key).await
async fn get_custom_value(&self, key: &[u8]) -> StoreResult<Option<Vec<u8>>> {
self.get_custom_value(key).await.map_err(Into::into)
}
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
self.set_custom_value(key, value).await
async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> StoreResult<Option<Vec<u8>>> {
self.set_custom_value(key, value).await.map_err(Into::into)
}
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> Result<()> {
self.add_media_content(request, data).await
async fn add_media_content(&self, request: &MediaRequest, data: Vec<u8>) -> StoreResult<()> {
self.add_media_content(request, data).await.map_err(Into::into)
}
async fn get_media_content(&self, request: &MediaRequest) -> Result<Option<Vec<u8>>> {
self.get_media_content(request).await
async fn get_media_content(&self, request: &MediaRequest) -> StoreResult<Option<Vec<u8>>> {
self.get_media_content(request).await.map_err(Into::into)
}
async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> {
self.remove_media_content(request).await
async fn remove_media_content(&self, request: &MediaRequest) -> StoreResult<()> {
self.remove_media_content(request).await.map_err(Into::into)
}
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
self.remove_media_content_for_uri(uri).await
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> StoreResult<()> {
self.remove_media_content_for_uri(uri).await.map_err(Into::into)
}
async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
self.remove_room(room_id).await
async fn remove_room(&self, room_id: &RoomId) -> StoreResult<()> {
self.remove_room(room_id).await.map_err(Into::into)
}
}
#[cfg(test)]
mod test {
use super::{Result, SledStore};
use super::{StoreResult, SledStore, StateStore};
use matrix_sdk_base::statestore_integration_tests;
async fn get_store() -> Result<SledStore> {
SledStore::open()
async fn get_store() -> StoreResult<impl StateStore> {
SledStore::open().map_err(Into::into)
}
statestore_integration_tests! { integration }