state-store: Replace methods to get and set key-value data

Avoid to make the API bigger when we want to add new key-value data
in the store.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2023-03-03 19:25:10 +01:00
committed by Jonas Platte
parent 5b28c69cfc
commit d77efd7327
12 changed files with 433 additions and 134 deletions

1
Cargo.lock generated
View File

@@ -2677,6 +2677,7 @@ dependencies = [
name = "matrix-sdk-base"
version = "0.6.1"
dependencies = [
"assert_matches",
"assign",
"async-stream",
"async-trait",

View File

@@ -23,9 +23,10 @@ qrcode = ["matrix-sdk-crypto?/qrcode"]
experimental-sliding-sync = ["ruma/unstable-msc3575"]
# helpers for testing features build upon this
testing = ["dep:http", "dep:matrix-sdk-test"]
testing = ["dep:http", "dep:matrix-sdk-test", "dep:assert_matches"]
[dependencies]
assert_matches = { version = "1.5.0", optional = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
dashmap = { workspace = true }
@@ -46,6 +47,7 @@ tracing = { workspace = true }
zeroize = { workspace = true, features = ["zeroize_derive"] }
[dev-dependencies]
assert_matches = "1.5.0"
assign = "1.1.1"
ctor = { workspace = true }
futures = { version = "0.3.21", default-features = false, features = ["executor"] }

View File

@@ -64,7 +64,7 @@ use crate::{
rooms::{Room, RoomInfo, RoomType},
store::{
ambiguity_map::AmbiguityCache, DynStateStore, Result as StoreResult, StateChanges,
StateStoreExt, Store, StoreConfig,
StateStoreDataKey, StateStoreDataValue, StateStoreExt, Store, StoreConfig,
},
sync::{JoinedRoom, LeftRoom, Rooms, SyncResponse, Timeline},
Session, SessionMeta, SessionTokens,
@@ -1025,7 +1025,13 @@ impl BaseClient {
filter_name: &str,
response: &api::filter::create_filter::v3::Response,
) -> Result<()> {
Ok(self.store.save_filter(filter_name, &response.filter_id).await?)
Ok(self
.store
.set_kv_data(
StateStoreDataKey::Filter(filter_name),
StateStoreDataValue::Filter(response.filter_id.clone()),
)
.await?)
}
/// Get the filter id of a previously uploaded filter.
@@ -1040,7 +1046,13 @@ impl BaseClient {
///
/// [`receive_filter_upload`]: #method.receive_filter_upload
pub async fn get_filter(&self, filter_name: &str) -> StoreResult<Option<String>> {
self.store.get_filter(filter_name).await
let filter = self
.store
.get_kv_data(StateStoreDataKey::Filter(filter_name))
.await?
.map(|d| d.into_filter().expect("State store data not a filter"));
Ok(filter)
}
/// Get a to-device request that will share a room key with users in a room.

View File

@@ -43,7 +43,7 @@ pub use http;
pub use matrix_sdk_crypto as crypto;
pub use once_cell;
pub use rooms::{DisplayName, Room, RoomInfo, RoomMember, RoomType};
pub use store::{StateChanges, StateStore, StoreError};
pub use store::{StateChanges, StateStore, StateStoreDataKey, StateStoreDataValue, StoreError};
pub use utils::{
MinimalRoomMemberEvent, MinimalStateEvent, OriginalMinimalStateEvent, RedactedMinimalStateEvent,
};

View File

@@ -2,6 +2,7 @@
use std::collections::{BTreeMap, BTreeSet};
use assert_matches::assert_matches;
use async_trait::async_trait;
use matrix_sdk_test::test_json;
use ruma::{
@@ -34,7 +35,7 @@ use crate::{
deserialized_responses::MemberEvent,
media::{MediaFormat, MediaRequest, MediaThumbnailSize},
store::{Result, StateStoreExt},
RoomInfo, RoomType, StateChanges,
RoomInfo, RoomType, StateChanges, StateStoreDataKey, StateStoreDataValue,
};
/// `StateStore` integration tests.
@@ -265,7 +266,7 @@ impl StateStoreIntegrationTests for DynStateStore {
let room_id = room_id();
self.populate().await?;
assert!(self.get_sync_token().await?.is_some());
assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some());
assert_eq!(
self.get_state_event_static::<RoomTopicEventContent>(room_id)
.await?
@@ -312,7 +313,7 @@ impl StateStoreIntegrationTests for DynStateStore {
let user_id = user_id();
self.populate().await?;
assert!(self.get_sync_token().await?.is_some());
assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some());
assert!(self.get_presence_event(user_id).await?.is_some());
assert_eq!(self.get_room_infos().await?.len(), 1, "Expected to find 1 room info");
assert_eq!(
@@ -401,21 +402,54 @@ impl StateStoreIntegrationTests for DynStateStore {
}
async fn test_filter_saving(&self) {
let test_name = "filter_name";
let filter_name = "filter_name";
let filter_id = "filter_id_1234";
assert_eq!(self.get_filter(test_name).await.unwrap(), None);
self.save_filter(test_name, filter_id).await.unwrap();
assert_eq!(self.get_filter(test_name).await.unwrap(), Some(filter_id.to_owned()));
self.set_kv_data(
StateStoreDataKey::Filter(filter_name),
StateStoreDataValue::Filter(filter_id.to_owned()),
)
.await
.unwrap();
let stored_filter_id = assert_matches!(
self.get_kv_data(StateStoreDataKey::Filter(filter_name)).await,
Ok(Some(StateStoreDataValue::Filter(s))) => s
);
assert_eq!(stored_filter_id, filter_id);
self.remove_kv_data(StateStoreDataKey::Filter(filter_name)).await.unwrap();
assert_matches!(self.get_kv_data(StateStoreDataKey::Filter(filter_name)).await, Ok(None));
}
async fn test_sync_token_saving(&self) {
let mut changes = StateChanges::default();
let sync_token = "t392-516_47314_0_7_1".to_owned();
let sync_token_1 = "t392-516_47314_0_7_1";
let sync_token_2 = "t392-516_47314_0_7_2";
changes.sync_token = Some(sync_token.clone());
assert_eq!(self.get_sync_token().await.unwrap(), None);
assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None));
let changes =
StateChanges { sync_token: Some(sync_token_1.to_owned()), ..Default::default() };
self.save_changes(&changes).await.unwrap();
assert_eq!(self.get_sync_token().await.unwrap(), Some(sync_token));
let stored_sync_token = assert_matches!(
self.get_kv_data(StateStoreDataKey::SyncToken).await,
Ok(Some(StateStoreDataValue::SyncToken(s))) => s
);
assert_eq!(stored_sync_token, sync_token_1);
self.set_kv_data(
StateStoreDataKey::SyncToken,
StateStoreDataValue::SyncToken(sync_token_2.to_owned()),
)
.await
.unwrap();
let stored_sync_token = assert_matches!(
self.get_kv_data(StateStoreDataKey::SyncToken).await,
Ok(Some(StateStoreDataValue::SyncToken(s))) => s
);
assert_eq!(stored_sync_token, sync_token_2);
self.remove_kv_data(StateStoreDataKey::SyncToken).await.unwrap();
assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None));
}
async fn test_stripped_member_saving(&self) {

View File

@@ -37,7 +37,10 @@ use ruma::{
use tracing::{debug, info, warn};
use super::{Result, RoomInfo, StateChanges, StateStore, StoreError};
use crate::{deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent};
use crate::{
deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent,
StateStoreDataKey, StateStoreDataValue,
};
/// In-Memory, non-persistent implementation of the `StateStore`
///
@@ -119,18 +122,48 @@ impl MemoryStore {
}
}
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.filters.insert(filter_name.to_owned(), filter_id.to_owned());
async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
match key {
StateStoreDataKey::SyncToken => {
Ok(self.sync_token.read().unwrap().clone().map(StateStoreDataValue::SyncToken))
}
StateStoreDataKey::Filter(filter_name) => Ok(self
.filters
.get(filter_name)
.map(|f| StateStoreDataValue::Filter(f.value().clone()))),
}
}
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<()> {
match key {
StateStoreDataKey::SyncToken => {
*self.sync_token.write().unwrap() =
Some(value.into_sync_token().expect("Session data not a sync token"))
}
StateStoreDataKey::Filter(filter_name) => {
self.filters.insert(
filter_name.to_owned(),
value.into_filter().expect("Session data not a filter"),
);
}
}
Ok(())
}
async fn get_filter(&self, filter_name: &str) -> Result<Option<String>> {
Ok(self.filters.get(filter_name).map(|f| f.to_string()))
}
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
match key {
StateStoreDataKey::SyncToken => *self.sync_token.write().unwrap() = None,
StateStoreDataKey::Filter(filter_name) => {
self.filters.remove(filter_name);
}
}
async fn get_sync_token(&self) -> Result<Option<String>> {
Ok(self.sync_token.read().unwrap().clone())
Ok(())
}
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
@@ -596,22 +629,26 @@ impl MemoryStore {
impl StateStore for MemoryStore {
type Error = StoreError;
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.save_filter(filter_name, filter_id).await
async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
self.get_kv_data(key).await
}
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<()> {
self.set_kv_data(key, value).await
}
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
self.remove_kv_data(key).await
}
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
self.save_changes(changes).await
}
async fn get_filter(&self, filter_id: &str) -> Result<Option<String>> {
self.get_filter(filter_id).await
}
async fn get_sync_token(&self) -> Result<Option<String>> {
self.get_sync_token().await
}
async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
self.get_presence_event(user_id).await
}

View File

@@ -74,7 +74,10 @@ mod memory_store;
pub use self::integration_tests::StateStoreIntegrationTests;
pub use self::{
memory_store::MemoryStore,
traits::{DynStateStore, IntoStateStore, StateStore, StateStoreExt},
traits::{
DynStateStore, IntoStateStore, StateStore, StateStoreDataKey, StateStoreDataValue,
StateStoreExt,
},
};
/// State store specific error type.
@@ -191,7 +194,8 @@ impl Store {
self.stripped_rooms.insert(room.room_id().to_owned(), room);
}
let token = self.get_sync_token().await?;
let token =
self.get_kv_data(StateStoreDataKey::SyncToken).await?.and_then(|s| s.into_sync_token());
*self.sync_token.write().await = token;
self.session_meta.set(session_meta).expect("Session Meta was already set");

View File

@@ -41,30 +41,43 @@ use crate::{
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait StateStore: AsyncTraitDeps {
/// The error type used by this state store.
type Error: fmt::Debug + Into<StoreError>;
type Error: fmt::Debug + Into<StoreError> + From<serde_json::Error>;
/// Save the given filter id under the given name.
/// Get key-value data from the store.
///
/// # Arguments
///
/// * `filter_name` - The name that should be used to store the filter id.
/// * `key` - The key to fetch data for.
async fn get_kv_data(
&self,
key: StateStoreDataKey<'_>,
) -> Result<Option<StateStoreDataValue>, Self::Error>;
/// Put key-value data into the store.
///
/// * `filter_id` - The filter id that should be stored in the state store.
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<(), Self::Error>;
/// # Arguments
///
/// * `key` - The key to identify the data in the store.
///
/// * `value` - The data to insert.
///
/// Panics if the key and value variants do not match.
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<(), Self::Error>;
/// Remove key-value data from the store.
///
/// # Arguments
///
/// * `key` - The key to remove the data for.
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error>;
/// Save the set of state changes in the store.
async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error>;
/// Get the filter id that was stored under the given filter name.
///
/// # Arguments
///
/// * `filter_name` - The name that was used to store the filter id.
async fn get_filter(&self, filter_name: &str) -> Result<Option<String>, Self::Error>;
/// Get the last stored sync token.
async fn get_sync_token(&self) -> Result<Option<String>, Self::Error>;
/// Get the stored presence event for the given user.
///
/// # Arguments
@@ -308,22 +321,29 @@ impl<T: fmt::Debug> fmt::Debug for EraseStateStoreError<T> {
impl<T: StateStore> StateStore for EraseStateStoreError<T> {
type Error = StoreError;
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<(), Self::Error> {
self.0.save_filter(filter_name, filter_id).await.map_err(Into::into)
async fn get_kv_data(
&self,
key: StateStoreDataKey<'_>,
) -> Result<Option<StateStoreDataValue>, Self::Error> {
self.0.get_kv_data(key).await.map_err(Into::into)
}
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<(), Self::Error> {
self.0.set_kv_data(key, value).await.map_err(Into::into)
}
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
self.0.remove_kv_data(key).await.map_err(Into::into)
}
async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
self.0.save_changes(changes).await.map_err(Into::into)
}
async fn get_filter(&self, filter_name: &str) -> Result<Option<String>, Self::Error> {
self.0.get_filter(filter_name).await.map_err(Into::into)
}
async fn get_sync_token(&self) -> Result<Option<String>, Self::Error> {
self.0.get_sync_token().await.map_err(Into::into)
}
async fn get_presence_event(
&self,
user_id: &UserId,
@@ -605,3 +625,51 @@ where
unsafe { Arc::from_raw(ptr_erased) }
}
}
/// A value for key-value data that should be persisted into the store.
#[derive(Debug, Clone)]
pub enum StateStoreDataValue {
/// The sync token.
SyncToken(String),
/// A filter with the given ID.
Filter(String),
}
impl StateStoreDataValue {
/// Get this value if it is a sync token.
pub fn into_sync_token(self) -> Option<String> {
match self {
Self::SyncToken(token) => Some(token),
Self::Filter(_) => None,
}
}
/// Get this value if it is a filter.
pub fn into_filter(self) -> Option<String> {
match self {
Self::SyncToken(_) => None,
Self::Filter(id) => Some(id),
}
}
}
/// A key for key-value data.
#[derive(Debug, Clone, Copy)]
pub enum StateStoreDataKey<'a> {
/// The sync token.
SyncToken,
/// A filter with the given name.
Filter(&'a str),
}
impl<'a> StateStoreDataKey<'a> {
/// The string to use to encode this key.
pub const fn encoding_key(&self) -> &str {
match self {
Self::SyncToken => "sync_token",
Self::Filter(_) => "filter",
}
}
}

View File

@@ -20,6 +20,7 @@ use std::{
use gloo_utils::format::JsValueSerdeExt;
use indexed_db_futures::{prelude::*, request::OpenDbRequest, IdbDatabase, IdbVersionChangeEvent};
use js_sys::Date as JsDate;
use matrix_sdk_base::StateStoreDataKey;
use matrix_sdk_store_encryption::StoreCipher;
use serde::{Deserialize, Serialize};
use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue};
@@ -376,12 +377,23 @@ async fn migrate_to_v4(
let sync_token = sync_token_store.get(&JsValue::from_str(OLD_KEYS::SYNC_TOKEN))?.await?;
if let Some(sync_token) = sync_token {
values.push((encode_key(store_cipher, KEYS::SYNC_TOKEN, KEYS::SYNC_TOKEN), sync_token));
values.push((
encode_key(
store_cipher,
StateStoreDataKey::SyncToken.encoding_key(),
StateStoreDataKey::SyncToken.encoding_key(),
),
sync_token,
));
}
// Filters
let session_store = tx.object_store(OLD_KEYS::SESSION)?;
let range = encode_to_range(store_cipher, KEYS::FILTER, KEYS::FILTER)?;
let range = encode_to_range(
store_cipher,
StateStoreDataKey::Filter("").encoding_key(),
StateStoreDataKey::Filter("").encoding_key(),
)?;
if let Some(cursor) = session_store.open_cursor_with_range(&range)?.await? {
while let Some(key) = cursor.key() {
let value = cursor.value();
@@ -410,7 +422,7 @@ mod tests {
use assert_matches::assert_matches;
use indexed_db_futures::prelude::*;
use matrix_sdk_base::{StateStore, StoreError};
use matrix_sdk_base::{StateStore, StateStoreDataKey, StoreError};
use matrix_sdk_test::async_test;
use ruma::{
events::{AnySyncStateEvent, StateEventType},
@@ -701,11 +713,19 @@ mod tests {
let session_store = tx.object_store(OLD_KEYS::SESSION)?;
session_store.put_key_val(
&encode_key(None, KEYS::FILTER, (KEYS::FILTER, filter_1)),
&encode_key(
None,
StateStoreDataKey::Filter("").encoding_key(),
(StateStoreDataKey::Filter("").encoding_key(), filter_1),
),
&serialize_event(None, &filter_1_id)?,
)?;
session_store.put_key_val(
&encode_key(None, KEYS::FILTER, (KEYS::FILTER, filter_2)),
&encode_key(
None,
StateStoreDataKey::Filter("").encoding_key(),
(StateStoreDataKey::Filter("").encoding_key(), filter_2),
),
&serialize_event(None, &filter_2_id)?,
)?;
@@ -716,13 +736,28 @@ mod tests {
// this transparently migrates to the latest version
let store = IndexeddbStateStore::builder().name(name).build().await?;
let stored_sync_token = store.get_sync_token().await.unwrap().unwrap();
let stored_sync_token = store
.get_kv_data(StateStoreDataKey::SyncToken)
.await?
.unwrap()
.into_sync_token()
.unwrap();
assert_eq!(stored_sync_token, sync_token);
let stored_filter_1_id = store.get_filter(filter_1).await.unwrap().unwrap();
let stored_filter_1_id = store
.get_kv_data(StateStoreDataKey::Filter(filter_1))
.await?
.unwrap()
.into_filter()
.unwrap();
assert_eq!(stored_filter_1_id, filter_1_id);
let stored_filter_2_id = store.get_filter(filter_2).await.unwrap().unwrap();
let stored_filter_2_id = store
.get_kv_data(StateStoreDataKey::Filter(filter_2))
.await?
.unwrap()
.into_filter()
.unwrap();
assert_eq!(stored_filter_2_id, filter_2_id);
Ok(())

View File

@@ -25,7 +25,7 @@ use matrix_sdk_base::{
deserialized_responses::RawMemberEvent,
media::{MediaRequest, UniqueKey},
store::{StateChanges, StateStore, StoreError},
MinimalStateEvent, RoomInfo,
MinimalStateEvent, RoomInfo, StateStoreDataKey, StateStoreDataValue,
};
use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher};
use ruma::{
@@ -147,8 +147,6 @@ mod KEYS {
// static keys
pub const STORE_KEY: &str = "store_key";
pub const FILTER: &str = "filter";
pub const SYNC_TOKEN: &str = "sync_token";
}
pub use KEYS::ALL_STORES;
@@ -413,6 +411,15 @@ impl IndexeddbStateStore {
.map(|f| self.deserialize_event(f))
.transpose()
}
fn encode_kv_data_key(&self, key: StateStoreDataKey<'_>) -> JsValue {
match key {
StateStoreDataKey::SyncToken => self.encode_key(key.encoding_key(), key.encoding_key()),
StateStoreDataKey::Filter(filter_name) => {
self.encode_key(key.encoding_key(), (key.encoding_key(), filter_name))
}
}
}
}
// Small hack to have the following macro invocation act as the appropriate
@@ -445,41 +452,71 @@ macro_rules! impl_state_store {
}
impl_state_store! {
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
async fn get_kv_data(
&self,
key: StateStoreDataKey<'_>,
) -> Result<Option<StateStoreDataValue>> {
let encoded_key = self.encode_kv_data_key(key);
let value = self
.inner
.transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readonly)?
.object_store(KEYS::KV)?
.get(&encoded_key)?
.await?
.map(|f| self.deserialize_event::<String>(f))
.transpose()?;
let value = match key {
StateStoreDataKey::SyncToken => value.map(StateStoreDataValue::SyncToken),
StateStoreDataKey::Filter(_) => value.map(StateStoreDataValue::Filter),
};
Ok(value)
}
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<()> {
let encoded_key = self.encode_kv_data_key(key);
let value = match key {
StateStoreDataKey::SyncToken => {
value.into_sync_token().expect("Session data not a sync token")
}
StateStoreDataKey::Filter(_) => {
value.into_filter().expect("Session data not a filter")
}
};
let tx = self
.inner
.transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readwrite)?;
let obj = tx.object_store(KEYS::KV)?;
obj.put_key_val(
&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)),
&self.serialize_event(&filter_id)?,
)?;
obj.put_key_val(&encoded_key, &self.serialize_event(&value)?)?;
tx.await.into_result()?;
Ok(())
}
async fn get_filter(&self, filter_name: &str) -> Result<Option<String>> {
self.inner
.transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readonly)?
.object_store(KEYS::KV)?
.get(&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
}
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
let encoded_key = self.encode_kv_data_key(key);
async fn get_sync_token(&self) -> Result<Option<String>> {
self.inner
.transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readonly)?
.object_store(KEYS::KV)?
.get(&self.encode_key(KEYS::SYNC_TOKEN, KEYS::SYNC_TOKEN))?
.await?
.map(|f| self.deserialize_event(f))
.transpose()
let tx = self
.inner
.transaction_on_one_with_mode(KEYS::KV, IdbTransactionMode::Readwrite)?;
let obj = tx.object_store(KEYS::KV)?;
obj.delete(&encoded_key)?;
tx.await.into_result()?;
Ok(())
}
async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
@@ -543,8 +580,10 @@ impl_state_store! {
self.inner.transaction_on_multi_with_mode(&stores, IdbTransactionMode::Readwrite)?;
if let Some(s) = &changes.sync_token {
tx.object_store(KEYS::KV)?
.put_key_val(&self.encode_key(KEYS::SYNC_TOKEN, KEYS::SYNC_TOKEN), &self.serialize_event(s)?)?;
tx.object_store(KEYS::KV)?.put_key_val(
&self.encode_kv_data_key(StateStoreDataKey::SyncToken),
&self.serialize_event(s)?,
)?;
}
if !changes.ambiguity_maps.is_empty() {

View File

@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use matrix_sdk_base::store::{Result as StoreResult, StoreError};
use matrix_sdk_base::{
store::{Result as StoreResult, StoreError},
StateStoreDataKey,
};
use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue};
use sled::{transaction::TransactionError, Batch, Transactional, Tree};
use tracing::debug;
@@ -167,13 +170,13 @@ impl SledStateStore {
let mut batch = sled::Batch::default();
// Sync token
let sync_token = session.get(keys::SYNC_TOKEN.encode())?;
let sync_token = session.get(StateStoreDataKey::SyncToken.encoding_key().encode())?;
if let Some(sync_token) = sync_token {
batch.insert(keys::SYNC_TOKEN.encode(), sync_token);
batch.insert(StateStoreDataKey::SyncToken.encoding_key().encode(), sync_token);
}
// Filters
let key = self.encode_key(keys::SESSION, keys::FILTER);
let key = self.encode_key(keys::SESSION, StateStoreDataKey::Filter("").encoding_key());
for res in session.scan_prefix(key) {
let (key, value) = res?;
batch.insert(key, value);
@@ -222,6 +225,7 @@ pub const V1_DB_STORES: &[&str] = &[
#[cfg(test)]
mod test {
use matrix_sdk_base::StateStoreDataKey;
use matrix_sdk_test::async_test;
use ruma::{
events::{AnySyncStateEvent, StateEventType},
@@ -391,13 +395,22 @@ mod test {
let session = store.inner.open_tree(old_keys::SESSION).unwrap();
let mut batch = sled::Batch::default();
batch.insert(keys::SYNC_TOKEN.encode(), store.serialize_value(&sync_token).unwrap());
batch.insert(
store.encode_key(keys::SESSION, (keys::FILTER, filter_1)),
StateStoreDataKey::SyncToken.encoding_key().encode(),
store.serialize_value(&sync_token).unwrap(),
);
batch.insert(
store.encode_key(
keys::SESSION,
(StateStoreDataKey::Filter("").encoding_key(), filter_1),
),
store.serialize_value(&filter_1_id).unwrap(),
);
batch.insert(
store.encode_key(keys::SESSION, (keys::FILTER, filter_2)),
store.encode_key(
keys::SESSION,
(StateStoreDataKey::Filter("").encoding_key(), filter_2),
),
store.serialize_value(&filter_2_id).unwrap(),
);
session.apply_batch(batch).unwrap();
@@ -412,13 +425,31 @@ mod test {
.build()
.unwrap();
let stored_sync_token = store.get_sync_token().await.unwrap().unwrap();
let stored_sync_token = store
.get_kv_data(StateStoreDataKey::SyncToken)
.await
.unwrap()
.unwrap()
.into_sync_token()
.unwrap();
assert_eq!(stored_sync_token, sync_token);
let stored_filter_1_id = store.get_filter(filter_1).await.unwrap().unwrap();
let stored_filter_1_id = store
.get_kv_data(StateStoreDataKey::Filter(filter_1))
.await
.unwrap()
.unwrap()
.into_filter()
.unwrap();
assert_eq!(stored_filter_1_id, filter_1_id);
let stored_filter_2_id = store.get_filter(filter_2).await.unwrap().unwrap();
let stored_filter_2_id = store
.get_kv_data(StateStoreDataKey::Filter(filter_2))
.await
.unwrap()
.unwrap()
.into_filter()
.unwrap();
assert_eq!(stored_filter_2_id, filter_2_id);
}
}

View File

@@ -26,7 +26,7 @@ use matrix_sdk_base::{
deserialized_responses::RawMemberEvent,
media::{MediaRequest, UniqueKey},
store::{Result as StoreResult, StateChanges, StateStore, StoreError},
MinimalStateEvent, RoomInfo,
MinimalStateEvent, RoomInfo, StateStoreDataKey, StateStoreDataValue,
};
use matrix_sdk_store_encryption::{Error as KeyEncryptionError, StoreCipher};
use ruma::{
@@ -104,9 +104,7 @@ impl From<SledStoreError> for StoreError {
mod keys {
// Static keys
pub const SYNC_TOKEN: &str = "sync_token";
pub const SESSION: &str = "session";
pub const FILTER: &str = "filter";
// Stores
pub const ACCOUNT_DATA: &str = "account-data";
@@ -411,23 +409,54 @@ impl SledStateStore {
}
}
pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> {
self.kv.insert(
self.encode_key(keys::SESSION, (keys::FILTER, filter_name)),
self.serialize_value(&filter_id)?,
)?;
fn encode_kv_data_key(&self, key: StateStoreDataKey<'_>) -> Vec<u8> {
match key {
StateStoreDataKey::SyncToken => key.encoding_key().encode(),
StateStoreDataKey::Filter(filter_name) => {
self.encode_key(keys::SESSION, (key.encoding_key(), filter_name))
}
}
}
async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
let encoded_key = self.encode_kv_data_key(key);
let value =
self.kv.get(encoded_key)?.map(|e| self.deserialize_value::<String>(&e)).transpose()?;
let value = match key {
StateStoreDataKey::SyncToken => value.map(StateStoreDataValue::SyncToken),
StateStoreDataKey::Filter(_) => value.map(StateStoreDataValue::Filter),
};
Ok(value)
}
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<()> {
let encoded_key = self.encode_kv_data_key(key);
let value = match key {
StateStoreDataKey::SyncToken => {
value.into_sync_token().expect("Session data not a sync token")
}
StateStoreDataKey::Filter(_) => value.into_filter().expect("Session data not a filter"),
};
self.kv.insert(encoded_key, self.serialize_value(&value)?)?;
Ok(())
}
pub async fn get_filter(&self, filter_name: &str) -> Result<Option<String>> {
self.kv
.get(self.encode_key(keys::SESSION, (keys::FILTER, filter_name)))?
.map(|f| self.deserialize_value(&f))
.transpose()
}
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
let encoded_key = self.encode_kv_data_key(key);
pub async fn get_sync_token(&self) -> Result<Option<String>> {
self.kv.get(keys::SYNC_TOKEN.encode())?.map(|t| self.deserialize_value(&t)).transpose()
self.kv.remove(encoded_key)?;
Ok(())
}
pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
@@ -792,7 +821,7 @@ impl SledStateStore {
.transaction(|(kv, account_data)| {
if let Some(s) = &changes.sync_token {
kv.insert(
keys::SYNC_TOKEN.encode(),
self.encode_kv_data_key(StateStoreDataKey::SyncToken),
self.serialize_value(s).map_err(ConflictableTransactionError::Abort)?,
)?;
}
@@ -1323,22 +1352,29 @@ impl SledStateStore {
impl StateStore for SledStateStore {
type Error = StoreError;
async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> {
self.save_filter(filter_name, filter_id).await.map_err(Into::into)
async fn get_kv_data(
&self,
key: StateStoreDataKey<'_>,
) -> StoreResult<Option<StateStoreDataValue>> {
self.get_kv_data(key).await.map_err(Into::into)
}
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> StoreResult<()> {
self.set_kv_data(key, value).await.map_err(Into::into)
}
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> StoreResult<()> {
self.remove_kv_data(key).await.map_err(Into::into)
}
async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> {
self.save_changes(changes).await.map_err(Into::into)
}
async fn get_filter(&self, filter_id: &str) -> StoreResult<Option<String>> {
self.get_filter(filter_id).await.map_err(Into::into)
}
async fn get_sync_token(&self) -> StoreResult<Option<String>> {
self.get_sync_token().await.map_err(Into::into)
}
async fn get_presence_event(
&self,
user_id: &UserId,