diff --git a/Cargo.lock b/Cargo.lock index 9f3b1018f..6b6aaf3b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2677,6 +2677,7 @@ dependencies = [ name = "matrix-sdk-base" version = "0.6.1" dependencies = [ + "assert_matches", "assign", "async-stream", "async-trait", diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index b3cf57ee9..37a4df13b 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -23,9 +23,10 @@ qrcode = ["matrix-sdk-crypto?/qrcode"] experimental-sliding-sync = ["ruma/unstable-msc3575"] # helpers for testing features build upon this -testing = ["dep:http", "dep:matrix-sdk-test"] +testing = ["dep:http", "dep:matrix-sdk-test", "dep:assert_matches"] [dependencies] +assert_matches = { version = "1.5.0", optional = true } async-stream = { workspace = true } async-trait = { workspace = true } dashmap = { workspace = true } @@ -46,6 +47,7 @@ tracing = { workspace = true } zeroize = { workspace = true, features = ["zeroize_derive"] } [dev-dependencies] +assert_matches = "1.5.0" assign = "1.1.1" ctor = { workspace = true } futures = { version = "0.3.21", default-features = false, features = ["executor"] } diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 9813a0192..c46b3ae02 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -64,7 +64,7 @@ use crate::{ rooms::{Room, RoomInfo, RoomType}, store::{ ambiguity_map::AmbiguityCache, DynStateStore, Result as StoreResult, StateChanges, - StateStoreExt, Store, StoreConfig, + StateStoreDataKey, StateStoreDataValue, StateStoreExt, Store, StoreConfig, }, sync::{JoinedRoom, LeftRoom, Rooms, SyncResponse, Timeline}, Session, SessionMeta, SessionTokens, @@ -1025,7 +1025,13 @@ impl BaseClient { filter_name: &str, response: &api::filter::create_filter::v3::Response, ) -> Result<()> { - Ok(self.store.save_filter(filter_name, &response.filter_id).await?) + Ok(self + .store + .set_kv_data( + StateStoreDataKey::Filter(filter_name), + StateStoreDataValue::Filter(response.filter_id.clone()), + ) + .await?) } /// Get the filter id of a previously uploaded filter. @@ -1040,7 +1046,13 @@ impl BaseClient { /// /// [`receive_filter_upload`]: #method.receive_filter_upload pub async fn get_filter(&self, filter_name: &str) -> StoreResult> { - self.store.get_filter(filter_name).await + let filter = self + .store + .get_kv_data(StateStoreDataKey::Filter(filter_name)) + .await? + .map(|d| d.into_filter().expect("State store data not a filter")); + + Ok(filter) } /// Get a to-device request that will share a room key with users in a room. diff --git a/crates/matrix-sdk-base/src/lib.rs b/crates/matrix-sdk-base/src/lib.rs index 8480ad95e..d411e5663 100644 --- a/crates/matrix-sdk-base/src/lib.rs +++ b/crates/matrix-sdk-base/src/lib.rs @@ -43,7 +43,7 @@ pub use http; pub use matrix_sdk_crypto as crypto; pub use once_cell; pub use rooms::{DisplayName, Room, RoomInfo, RoomMember, RoomType}; -pub use store::{StateChanges, StateStore, StoreError}; +pub use store::{StateChanges, StateStore, StateStoreDataKey, StateStoreDataValue, StoreError}; pub use utils::{ MinimalRoomMemberEvent, MinimalStateEvent, OriginalMinimalStateEvent, RedactedMinimalStateEvent, }; diff --git a/crates/matrix-sdk-base/src/store/integration_tests.rs b/crates/matrix-sdk-base/src/store/integration_tests.rs index 87a6195a5..982ed149a 100644 --- a/crates/matrix-sdk-base/src/store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/store/integration_tests.rs @@ -2,6 +2,7 @@ use std::collections::{BTreeMap, BTreeSet}; +use assert_matches::assert_matches; use async_trait::async_trait; use matrix_sdk_test::test_json; use ruma::{ @@ -34,7 +35,7 @@ use crate::{ deserialized_responses::MemberEvent, media::{MediaFormat, MediaRequest, MediaThumbnailSize}, store::{Result, StateStoreExt}, - RoomInfo, RoomType, StateChanges, + RoomInfo, RoomType, StateChanges, StateStoreDataKey, StateStoreDataValue, }; /// `StateStore` integration tests. @@ -265,7 +266,7 @@ impl StateStoreIntegrationTests for DynStateStore { let room_id = room_id(); self.populate().await?; - assert!(self.get_sync_token().await?.is_some()); + assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some()); assert_eq!( self.get_state_event_static::(room_id) .await? @@ -312,7 +313,7 @@ impl StateStoreIntegrationTests for DynStateStore { let user_id = user_id(); self.populate().await?; - assert!(self.get_sync_token().await?.is_some()); + assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some()); assert!(self.get_presence_event(user_id).await?.is_some()); assert_eq!(self.get_room_infos().await?.len(), 1, "Expected to find 1 room info"); assert_eq!( @@ -401,21 +402,54 @@ impl StateStoreIntegrationTests for DynStateStore { } async fn test_filter_saving(&self) { - let test_name = "filter_name"; + let filter_name = "filter_name"; let filter_id = "filter_id_1234"; - assert_eq!(self.get_filter(test_name).await.unwrap(), None); - self.save_filter(test_name, filter_id).await.unwrap(); - assert_eq!(self.get_filter(test_name).await.unwrap(), Some(filter_id.to_owned())); + + self.set_kv_data( + StateStoreDataKey::Filter(filter_name), + StateStoreDataValue::Filter(filter_id.to_owned()), + ) + .await + .unwrap(); + let stored_filter_id = assert_matches!( + self.get_kv_data(StateStoreDataKey::Filter(filter_name)).await, + Ok(Some(StateStoreDataValue::Filter(s))) => s + ); + assert_eq!(stored_filter_id, filter_id); + + self.remove_kv_data(StateStoreDataKey::Filter(filter_name)).await.unwrap(); + assert_matches!(self.get_kv_data(StateStoreDataKey::Filter(filter_name)).await, Ok(None)); } async fn test_sync_token_saving(&self) { - let mut changes = StateChanges::default(); - let sync_token = "t392-516_47314_0_7_1".to_owned(); + let sync_token_1 = "t392-516_47314_0_7_1"; + let sync_token_2 = "t392-516_47314_0_7_2"; - changes.sync_token = Some(sync_token.clone()); - assert_eq!(self.get_sync_token().await.unwrap(), None); + assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None)); + + let changes = + StateChanges { sync_token: Some(sync_token_1.to_owned()), ..Default::default() }; self.save_changes(&changes).await.unwrap(); - assert_eq!(self.get_sync_token().await.unwrap(), Some(sync_token)); + let stored_sync_token = assert_matches!( + self.get_kv_data(StateStoreDataKey::SyncToken).await, + Ok(Some(StateStoreDataValue::SyncToken(s))) => s + ); + assert_eq!(stored_sync_token, sync_token_1); + + self.set_kv_data( + StateStoreDataKey::SyncToken, + StateStoreDataValue::SyncToken(sync_token_2.to_owned()), + ) + .await + .unwrap(); + let stored_sync_token = assert_matches!( + self.get_kv_data(StateStoreDataKey::SyncToken).await, + Ok(Some(StateStoreDataValue::SyncToken(s))) => s + ); + assert_eq!(stored_sync_token, sync_token_2); + + self.remove_kv_data(StateStoreDataKey::SyncToken).await.unwrap(); + assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None)); } async fn test_stripped_member_saving(&self) { diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index 20a3972bf..14dcc74ef 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -37,7 +37,10 @@ use ruma::{ use tracing::{debug, info, warn}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; -use crate::{deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent}; +use crate::{ + deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent, + StateStoreDataKey, StateStoreDataValue, +}; /// In-Memory, non-persistent implementation of the `StateStore` /// @@ -119,18 +122,48 @@ impl MemoryStore { } } - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.filters.insert(filter_name.to_owned(), filter_id.to_owned()); + async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + match key { + StateStoreDataKey::SyncToken => { + Ok(self.sync_token.read().unwrap().clone().map(StateStoreDataValue::SyncToken)) + } + StateStoreDataKey::Filter(filter_name) => Ok(self + .filters + .get(filter_name) + .map(|f| StateStoreDataValue::Filter(f.value().clone()))), + } + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + match key { + StateStoreDataKey::SyncToken => { + *self.sync_token.write().unwrap() = + Some(value.into_sync_token().expect("Session data not a sync token")) + } + StateStoreDataKey::Filter(filter_name) => { + self.filters.insert( + filter_name.to_owned(), + value.into_filter().expect("Session data not a filter"), + ); + } + } Ok(()) } - async fn get_filter(&self, filter_name: &str) -> Result> { - Ok(self.filters.get(filter_name).map(|f| f.to_string())) - } + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + match key { + StateStoreDataKey::SyncToken => *self.sync_token.write().unwrap() = None, + StateStoreDataKey::Filter(filter_name) => { + self.filters.remove(filter_name); + } + } - async fn get_sync_token(&self) -> Result> { - Ok(self.sync_token.read().unwrap().clone()) + Ok(()) } async fn save_changes(&self, changes: &StateChanges) -> Result<()> { @@ -596,22 +629,26 @@ impl MemoryStore { impl StateStore for MemoryStore { type Error = StoreError; - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.save_filter(filter_name, filter_id).await + async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + self.get_kv_data(key).await + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + self.set_kv_data(key, value).await + } + + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + self.remove_kv_data(key).await } async fn save_changes(&self, changes: &StateChanges) -> Result<()> { self.save_changes(changes).await } - async fn get_filter(&self, filter_id: &str) -> Result> { - self.get_filter(filter_id).await - } - - async fn get_sync_token(&self) -> Result> { - self.get_sync_token().await - } - async fn get_presence_event(&self, user_id: &UserId) -> Result>> { self.get_presence_event(user_id).await } diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index 2bbb93a41..88104e539 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -74,7 +74,10 @@ mod memory_store; pub use self::integration_tests::StateStoreIntegrationTests; pub use self::{ memory_store::MemoryStore, - traits::{DynStateStore, IntoStateStore, StateStore, StateStoreExt}, + traits::{ + DynStateStore, IntoStateStore, StateStore, StateStoreDataKey, StateStoreDataValue, + StateStoreExt, + }, }; /// State store specific error type. @@ -191,7 +194,8 @@ impl Store { self.stripped_rooms.insert(room.room_id().to_owned(), room); } - let token = self.get_sync_token().await?; + let token = + self.get_kv_data(StateStoreDataKey::SyncToken).await?.and_then(|s| s.into_sync_token()); *self.sync_token.write().await = token; self.session_meta.set(session_meta).expect("Session Meta was already set"); diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs index d3968db87..6659b75d5 100644 --- a/crates/matrix-sdk-base/src/store/traits.rs +++ b/crates/matrix-sdk-base/src/store/traits.rs @@ -41,30 +41,43 @@ use crate::{ #[cfg_attr(not(target_arch = "wasm32"), async_trait)] pub trait StateStore: AsyncTraitDeps { /// The error type used by this state store. - type Error: fmt::Debug + Into; + type Error: fmt::Debug + Into + From; - /// Save the given filter id under the given name. + /// Get key-value data from the store. /// /// # Arguments /// - /// * `filter_name` - The name that should be used to store the filter id. + /// * `key` - The key to fetch data for. + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result, Self::Error>; + + /// Put key-value data into the store. /// - /// * `filter_id` - The filter id that should be stored in the state store. - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<(), Self::Error>; + /// # Arguments + /// + /// * `key` - The key to identify the data in the store. + /// + /// * `value` - The data to insert. + /// + /// Panics if the key and value variants do not match. + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<(), Self::Error>; + + /// Remove key-value data from the store. + /// + /// # Arguments + /// + /// * `key` - The key to remove the data for. + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error>; /// Save the set of state changes in the store. async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error>; - /// Get the filter id that was stored under the given filter name. - /// - /// # Arguments - /// - /// * `filter_name` - The name that was used to store the filter id. - async fn get_filter(&self, filter_name: &str) -> Result, Self::Error>; - - /// Get the last stored sync token. - async fn get_sync_token(&self) -> Result, Self::Error>; - /// Get the stored presence event for the given user. /// /// # Arguments @@ -308,22 +321,29 @@ impl fmt::Debug for EraseStateStoreError { impl StateStore for EraseStateStoreError { type Error = StoreError; - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<(), Self::Error> { - self.0.save_filter(filter_name, filter_id).await.map_err(Into::into) + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result, Self::Error> { + self.0.get_kv_data(key).await.map_err(Into::into) + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<(), Self::Error> { + self.0.set_kv_data(key, value).await.map_err(Into::into) + } + + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> { + self.0.remove_kv_data(key).await.map_err(Into::into) } async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> { self.0.save_changes(changes).await.map_err(Into::into) } - async fn get_filter(&self, filter_name: &str) -> Result, Self::Error> { - self.0.get_filter(filter_name).await.map_err(Into::into) - } - - async fn get_sync_token(&self) -> Result, Self::Error> { - self.0.get_sync_token().await.map_err(Into::into) - } - async fn get_presence_event( &self, user_id: &UserId, @@ -605,3 +625,51 @@ where unsafe { Arc::from_raw(ptr_erased) } } } + +/// A value for key-value data that should be persisted into the store. +#[derive(Debug, Clone)] +pub enum StateStoreDataValue { + /// The sync token. + SyncToken(String), + + /// A filter with the given ID. + Filter(String), +} + +impl StateStoreDataValue { + /// Get this value if it is a sync token. + pub fn into_sync_token(self) -> Option { + match self { + Self::SyncToken(token) => Some(token), + Self::Filter(_) => None, + } + } + + /// Get this value if it is a filter. + pub fn into_filter(self) -> Option { + match self { + Self::SyncToken(_) => None, + Self::Filter(id) => Some(id), + } + } +} + +/// A key for key-value data. +#[derive(Debug, Clone, Copy)] +pub enum StateStoreDataKey<'a> { + /// The sync token. + SyncToken, + + /// A filter with the given name. + Filter(&'a str), +} + +impl<'a> StateStoreDataKey<'a> { + /// The string to use to encode this key. + pub const fn encoding_key(&self) -> &str { + match self { + Self::SyncToken => "sync_token", + Self::Filter(_) => "filter", + } + } +} diff --git a/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs b/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs index fbee5c291..ca8e76f37 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs @@ -20,6 +20,7 @@ use std::{ use gloo_utils::format::JsValueSerdeExt; use indexed_db_futures::{prelude::*, request::OpenDbRequest, IdbDatabase, IdbVersionChangeEvent}; use js_sys::Date as JsDate; +use matrix_sdk_base::StateStoreDataKey; use matrix_sdk_store_encryption::StoreCipher; use serde::{Deserialize, Serialize}; use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue}; @@ -376,12 +377,23 @@ async fn migrate_to_v4( let sync_token = sync_token_store.get(&JsValue::from_str(OLD_KEYS::SYNC_TOKEN))?.await?; if let Some(sync_token) = sync_token { - values.push((encode_key(store_cipher, KEYS::SYNC_TOKEN, KEYS::SYNC_TOKEN), sync_token)); + values.push(( + encode_key( + store_cipher, + StateStoreDataKey::SyncToken.encoding_key(), + StateStoreDataKey::SyncToken.encoding_key(), + ), + sync_token, + )); } // Filters let session_store = tx.object_store(OLD_KEYS::SESSION)?; - let range = encode_to_range(store_cipher, KEYS::FILTER, KEYS::FILTER)?; + let range = encode_to_range( + store_cipher, + StateStoreDataKey::Filter("").encoding_key(), + StateStoreDataKey::Filter("").encoding_key(), + )?; if let Some(cursor) = session_store.open_cursor_with_range(&range)?.await? { while let Some(key) = cursor.key() { let value = cursor.value(); @@ -410,7 +422,7 @@ mod tests { use assert_matches::assert_matches; use indexed_db_futures::prelude::*; - use matrix_sdk_base::{StateStore, StoreError}; + use matrix_sdk_base::{StateStore, StateStoreDataKey, StoreError}; use matrix_sdk_test::async_test; use ruma::{ events::{AnySyncStateEvent, StateEventType}, @@ -701,11 +713,19 @@ mod tests { let session_store = tx.object_store(OLD_KEYS::SESSION)?; session_store.put_key_val( - &encode_key(None, KEYS::FILTER, (KEYS::FILTER, filter_1)), + &encode_key( + None, + StateStoreDataKey::Filter("").encoding_key(), + (StateStoreDataKey::Filter("").encoding_key(), filter_1), + ), &serialize_event(None, &filter_1_id)?, )?; session_store.put_key_val( - &encode_key(None, KEYS::FILTER, (KEYS::FILTER, filter_2)), + &encode_key( + None, + StateStoreDataKey::Filter("").encoding_key(), + (StateStoreDataKey::Filter("").encoding_key(), filter_2), + ), &serialize_event(None, &filter_2_id)?, )?; @@ -716,13 +736,28 @@ mod tests { // this transparently migrates to the latest version let store = IndexeddbStateStore::builder().name(name).build().await?; - let stored_sync_token = store.get_sync_token().await.unwrap().unwrap(); + let stored_sync_token = store + .get_kv_data(StateStoreDataKey::SyncToken) + .await? + .unwrap() + .into_sync_token() + .unwrap(); assert_eq!(stored_sync_token, sync_token); - let stored_filter_1_id = store.get_filter(filter_1).await.unwrap().unwrap(); + let stored_filter_1_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_1)) + .await? + .unwrap() + .into_filter() + .unwrap(); assert_eq!(stored_filter_1_id, filter_1_id); - let stored_filter_2_id = store.get_filter(filter_2).await.unwrap().unwrap(); + let stored_filter_2_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_2)) + .await? + .unwrap() + .into_filter() + .unwrap(); assert_eq!(stored_filter_2_id, filter_2_id); Ok(()) diff --git a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs index 2f4fe133a..cb615423f 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs @@ -25,7 +25,7 @@ use matrix_sdk_base::{ deserialized_responses::RawMemberEvent, media::{MediaRequest, UniqueKey}, store::{StateChanges, StateStore, StoreError}, - MinimalStateEvent, RoomInfo, + MinimalStateEvent, RoomInfo, StateStoreDataKey, StateStoreDataValue, }; use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher}; use ruma::{ @@ -147,8 +147,6 @@ mod KEYS { // static keys pub const STORE_KEY: &str = "store_key"; - pub const FILTER: &str = "filter"; - pub const SYNC_TOKEN: &str = "sync_token"; } pub use KEYS::ALL_STORES; @@ -413,6 +411,15 @@ impl IndexeddbStateStore { .map(|f| self.deserialize_event(f)) .transpose() } + + fn encode_kv_data_key(&self, key: StateStoreDataKey<'_>) -> JsValue { + match key { + StateStoreDataKey::SyncToken => self.encode_key(key.encoding_key(), key.encoding_key()), + StateStoreDataKey::Filter(filter_name) => { + self.encode_key(key.encoding_key(), (key.encoding_key(), filter_name)) + } + } + } } // Small hack to have the following macro invocation act as the appropriate @@ -445,41 +452,71 @@ macro_rules! impl_state_store { } impl_state_store! { - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result> { + let encoded_key = self.encode_kv_data_key(key); + + let value = self + .inner + .transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readonly)? + .object_store(KEYS::KV)? + .get(&encoded_key)? + .await? + .map(|f| self.deserialize_event::(f)) + .transpose()?; + + let value = match key { + StateStoreDataKey::SyncToken => value.map(StateStoreDataValue::SyncToken), + StateStoreDataKey::Filter(_) => value.map(StateStoreDataValue::Filter), + }; + + Ok(value) + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); + + let value = match key { + StateStoreDataKey::SyncToken => { + value.into_sync_token().expect("Session data not a sync token") + } + StateStoreDataKey::Filter(_) => { + value.into_filter().expect("Session data not a filter") + } + }; + let tx = self .inner .transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readwrite)?; let obj = tx.object_store(KEYS::KV)?; - obj.put_key_val( - &self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)), - &self.serialize_event(&filter_id)?, - )?; + obj.put_key_val(&encoded_key, &self.serialize_event(&value)?)?; tx.await.into_result()?; Ok(()) } - async fn get_filter(&self, filter_name: &str) -> Result> { - self.inner - .transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readonly)? - .object_store(KEYS::KV)? - .get(&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); - async fn get_sync_token(&self) -> Result> { - self.inner - .transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readonly)? - .object_store(KEYS::KV)? - .get(&self.encode_key(KEYS::SYNC_TOKEN, KEYS::SYNC_TOKEN))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() + let tx = self + .inner + .transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readwrite)?; + let obj = tx.object_store(KEYS::KV)?; + + obj.delete(&encoded_key)?; + + tx.await.into_result()?; + + Ok(()) } async fn save_changes(&self, changes: &StateChanges) -> Result<()> { @@ -543,8 +580,10 @@ impl_state_store! { self.inner.transaction_on_multi_with_mode(&stores, IdbTransactionMode::Readwrite)?; if let Some(s) = &changes.sync_token { - tx.object_store(KEYS::KV)? - .put_key_val(&self.encode_key(KEYS::SYNC_TOKEN, KEYS::SYNC_TOKEN), &self.serialize_event(s)?)?; + tx.object_store(KEYS::KV)?.put_key_val( + &self.encode_kv_data_key(StateStoreDataKey::SyncToken), + &self.serialize_event(s)?, + )?; } if !changes.ambiguity_maps.is_empty() { diff --git a/crates/matrix-sdk-sled/src/state_store/migrations.rs b/crates/matrix-sdk-sled/src/state_store/migrations.rs index 2c996e8a6..0f3df7541 100644 --- a/crates/matrix-sdk-sled/src/state_store/migrations.rs +++ b/crates/matrix-sdk-sled/src/state_store/migrations.rs @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use matrix_sdk_base::store::{Result as StoreResult, StoreError}; +use matrix_sdk_base::{ + store::{Result as StoreResult, StoreError}, + StateStoreDataKey, +}; use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue}; use sled::{transaction::TransactionError, Batch, Transactional, Tree}; use tracing::debug; @@ -167,13 +170,13 @@ impl SledStateStore { let mut batch = sled::Batch::default(); // Sync token - let sync_token = session.get(keys::SYNC_TOKEN.encode())?; + let sync_token = session.get(StateStoreDataKey::SyncToken.encoding_key().encode())?; if let Some(sync_token) = sync_token { - batch.insert(keys::SYNC_TOKEN.encode(), sync_token); + batch.insert(StateStoreDataKey::SyncToken.encoding_key().encode(), sync_token); } // Filters - let key = self.encode_key(keys::SESSION, keys::FILTER); + let key = self.encode_key(keys::SESSION, StateStoreDataKey::Filter("").encoding_key()); for res in session.scan_prefix(key) { let (key, value) = res?; batch.insert(key, value); @@ -222,6 +225,7 @@ pub const V1_DB_STORES: &[&str] = &[ #[cfg(test)] mod test { + use matrix_sdk_base::StateStoreDataKey; use matrix_sdk_test::async_test; use ruma::{ events::{AnySyncStateEvent, StateEventType}, @@ -391,13 +395,22 @@ mod test { let session = store.inner.open_tree(old_keys::SESSION).unwrap(); let mut batch = sled::Batch::default(); - batch.insert(keys::SYNC_TOKEN.encode(), store.serialize_value(&sync_token).unwrap()); batch.insert( - store.encode_key(keys::SESSION, (keys::FILTER, filter_1)), + StateStoreDataKey::SyncToken.encoding_key().encode(), + store.serialize_value(&sync_token).unwrap(), + ); + batch.insert( + store.encode_key( + keys::SESSION, + (StateStoreDataKey::Filter("").encoding_key(), filter_1), + ), store.serialize_value(&filter_1_id).unwrap(), ); batch.insert( - store.encode_key(keys::SESSION, (keys::FILTER, filter_2)), + store.encode_key( + keys::SESSION, + (StateStoreDataKey::Filter("").encoding_key(), filter_2), + ), store.serialize_value(&filter_2_id).unwrap(), ); session.apply_batch(batch).unwrap(); @@ -412,13 +425,31 @@ mod test { .build() .unwrap(); - let stored_sync_token = store.get_sync_token().await.unwrap().unwrap(); + let stored_sync_token = store + .get_kv_data(StateStoreDataKey::SyncToken) + .await + .unwrap() + .unwrap() + .into_sync_token() + .unwrap(); assert_eq!(stored_sync_token, sync_token); - let stored_filter_1_id = store.get_filter(filter_1).await.unwrap().unwrap(); + let stored_filter_1_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_1)) + .await + .unwrap() + .unwrap() + .into_filter() + .unwrap(); assert_eq!(stored_filter_1_id, filter_1_id); - let stored_filter_2_id = store.get_filter(filter_2).await.unwrap().unwrap(); + let stored_filter_2_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_2)) + .await + .unwrap() + .unwrap() + .into_filter() + .unwrap(); assert_eq!(stored_filter_2_id, filter_2_id); } } diff --git a/crates/matrix-sdk-sled/src/state_store/mod.rs b/crates/matrix-sdk-sled/src/state_store/mod.rs index 0afb3a878..9fbd2b724 100644 --- a/crates/matrix-sdk-sled/src/state_store/mod.rs +++ b/crates/matrix-sdk-sled/src/state_store/mod.rs @@ -26,7 +26,7 @@ use matrix_sdk_base::{ deserialized_responses::RawMemberEvent, media::{MediaRequest, UniqueKey}, store::{Result as StoreResult, StateChanges, StateStore, StoreError}, - MinimalStateEvent, RoomInfo, + MinimalStateEvent, RoomInfo, StateStoreDataKey, StateStoreDataValue, }; use matrix_sdk_store_encryption::{Error as KeyEncryptionError, StoreCipher}; use ruma::{ @@ -104,9 +104,7 @@ impl From for StoreError { mod keys { // Static keys - pub const SYNC_TOKEN: &str = "sync_token"; pub const SESSION: &str = "session"; - pub const FILTER: &str = "filter"; // Stores pub const ACCOUNT_DATA: &str = "account-data"; @@ -411,23 +409,54 @@ impl SledStateStore { } } - pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.kv.insert( - self.encode_key(keys::SESSION, (keys::FILTER, filter_name)), - self.serialize_value(&filter_id)?, - )?; + fn encode_kv_data_key(&self, key: StateStoreDataKey<'_>) -> Vec { + match key { + StateStoreDataKey::SyncToken => key.encoding_key().encode(), + StateStoreDataKey::Filter(filter_name) => { + self.encode_key(keys::SESSION, (key.encoding_key(), filter_name)) + } + } + } + + async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + let encoded_key = self.encode_kv_data_key(key); + + let value = + self.kv.get(encoded_key)?.map(|e| self.deserialize_value::(&e)).transpose()?; + + let value = match key { + StateStoreDataKey::SyncToken => value.map(StateStoreDataValue::SyncToken), + StateStoreDataKey::Filter(_) => value.map(StateStoreDataValue::Filter), + }; + + Ok(value) + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); + + let value = match key { + StateStoreDataKey::SyncToken => { + value.into_sync_token().expect("Session data not a sync token") + } + StateStoreDataKey::Filter(_) => value.into_filter().expect("Session data not a filter"), + }; + + self.kv.insert(encoded_key, self.serialize_value(&value)?)?; + Ok(()) } - pub async fn get_filter(&self, filter_name: &str) -> Result> { - self.kv - .get(self.encode_key(keys::SESSION, (keys::FILTER, filter_name)))? - .map(|f| self.deserialize_value(&f)) - .transpose() - } + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); - pub async fn get_sync_token(&self) -> Result> { - self.kv.get(keys::SYNC_TOKEN.encode())?.map(|t| self.deserialize_value(&t)).transpose() + self.kv.remove(encoded_key)?; + + Ok(()) } pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { @@ -792,7 +821,7 @@ impl SledStateStore { .transaction(|(kv, account_data)| { if let Some(s) = &changes.sync_token { kv.insert( - keys::SYNC_TOKEN.encode(), + self.encode_kv_data_key(StateStoreDataKey::SyncToken), self.serialize_value(s).map_err(ConflictableTransactionError::Abort)?, )?; } @@ -1323,22 +1352,29 @@ impl SledStateStore { impl StateStore for SledStateStore { type Error = StoreError; - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> { - self.save_filter(filter_name, filter_id).await.map_err(Into::into) + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> StoreResult> { + self.get_kv_data(key).await.map_err(Into::into) + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> StoreResult<()> { + self.set_kv_data(key, value).await.map_err(Into::into) + } + + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> StoreResult<()> { + self.remove_kv_data(key).await.map_err(Into::into) } async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> { self.save_changes(changes).await.map_err(Into::into) } - async fn get_filter(&self, filter_id: &str) -> StoreResult> { - self.get_filter(filter_id).await.map_err(Into::into) - } - - async fn get_sync_token(&self) -> StoreResult> { - self.get_sync_token().await.map_err(Into::into) - } - async fn get_presence_event( &self, user_id: &UserId,