From 7eefcfb72e91c809cacdfdd0b8d1cfacd9798d36 Mon Sep 17 00:00:00 2001 From: Michael Goldenberg Date: Tue, 14 Apr 2026 13:09:25 -0400 Subject: [PATCH] feat(base): add state store wrapper providing synchronized calls to save_changes Signed-off-by: Michael Goldenberg --- crates/matrix-sdk-base/src/store/traits.rs | 360 +++++++++++++++++++++ 1 file changed, 360 insertions(+) diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs index 80497bd5b..336ae6ad5 100644 --- a/crates/matrix-sdk-base/src/store/traits.rs +++ b/crates/matrix-sdk-base/src/store/traits.rs @@ -46,6 +46,8 @@ use ruma::{ serde::Raw, }; use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio::sync::{Mutex, MutexGuard}; use super::{ ChildTransactionId, DependentQueuedRequest, DependentQueuedRequestKind, QueueWedgeError, @@ -851,6 +853,364 @@ impl StateStore for EraseStateStoreError { } } +/// A wrapper around a [`StateStore`] that supports synchronizing calls to +/// [`StateStore::save_changes`]. +#[allow(unused)] +#[derive(Debug, Clone)] +pub struct SaveLockedStateStore> { + store: T, + lock: Arc>, +} + +/// An error type that represents a scenario where a [`MutexGuard`] provided to +/// a function does not reference the underlying [`Mutex`] in the enclosing +/// [`SaveLockedStateStore`]. +#[allow(unused)] +#[derive(Debug, Error)] +#[error("a mutex guard was provided, but it does not reference the correct mutex")] +pub struct IncorrectMutexGuardError; + +impl From for StoreError { + fn from(value: IncorrectMutexGuardError) -> Self { + Self::backend(value) + } +} + +impl SaveLockedStateStore { + /// Creates a new [`SaveLockedStateStore`] with the provided store. + #[allow(unused)] + pub fn new(store: T) -> Self { + Self { store, lock: Arc::new(Mutex::new(())) } + } + + /// Returns a reference to the underlying [`Mutex`] used to synchronize + /// calls to [`StateStore::save_changes`]. + #[allow(unused)] + pub fn lock(&self) -> &Mutex<()> { + self.lock.as_ref() + } +} + +impl SaveLockedStateStore { + /// Provides a means of calling [`StateStore::save_changes`] when the caller + /// has already acquired the underlying [`Mutex`]. Returns an error if + /// the [`MutexGuard`] provided does not reference the underlying + /// [`Mutex`]. + #[allow(unused)] + pub async fn save_changes_with_guard( + &self, + guard: &MutexGuard<'_, ()>, + changes: &StateChanges, + ) -> Result<(), StoreError> { + if !std::ptr::eq(MutexGuard::mutex(guard), self.lock()) { + Err(IncorrectMutexGuardError.into()) + } else { + self.store.save_changes(changes).await.map_err(Into::into) + } + } +} + +#[cfg_attr(target_family = "wasm", async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait)] +impl StateStore for SaveLockedStateStore { + type Error = T::Error; + + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result, Self::Error> { + self.store.get_kv_data(key).await + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<(), Self::Error> { + self.store.set_kv_data(key, value).await + } + + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> { + self.store.remove_kv_data(key).await + } + + async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> { + let _guard = self.lock.lock().await; + self.store.save_changes(changes).await + } + + async fn get_presence_event( + &self, + user_id: &UserId, + ) -> Result>, Self::Error> { + self.store.get_presence_event(user_id).await + } + + async fn get_presence_events( + &self, + user_ids: &[OwnedUserId], + ) -> Result>, Self::Error> { + self.store.get_presence_events(user_ids).await + } + + async fn get_state_event( + &self, + room_id: &RoomId, + event_type: StateEventType, + state_key: &str, + ) -> Result, Self::Error> { + self.store.get_state_event(room_id, event_type, state_key).await + } + + async fn get_state_events( + &self, + room_id: &RoomId, + event_type: StateEventType, + ) -> Result, Self::Error> { + self.store.get_state_events(room_id, event_type).await + } + + async fn get_state_events_for_keys( + &self, + room_id: &RoomId, + event_type: StateEventType, + state_keys: &[&str], + ) -> Result, Self::Error> { + self.store.get_state_events_for_keys(room_id, event_type, state_keys).await + } + + async fn get_profile( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result, Self::Error> { + self.store.get_profile(room_id, user_id).await + } + + async fn get_profiles<'a>( + &self, + room_id: &RoomId, + user_ids: &'a [OwnedUserId], + ) -> Result, Self::Error> { + self.store.get_profiles(room_id, user_ids).await + } + + async fn get_user_ids( + &self, + room_id: &RoomId, + memberships: RoomMemberships, + ) -> Result, Self::Error> { + self.store.get_user_ids(room_id, memberships).await + } + + async fn get_room_infos( + &self, + room_load_settings: &RoomLoadSettings, + ) -> Result, Self::Error> { + self.store.get_room_infos(room_load_settings).await + } + + async fn get_users_with_display_name( + &self, + room_id: &RoomId, + display_name: &DisplayName, + ) -> Result, Self::Error> { + self.store.get_users_with_display_name(room_id, display_name).await + } + + async fn get_users_with_display_names<'a>( + &self, + room_id: &RoomId, + display_names: &'a [DisplayName], + ) -> Result>, Self::Error> { + self.store.get_users_with_display_names(room_id, display_names).await + } + + async fn get_account_data_event( + &self, + event_type: GlobalAccountDataEventType, + ) -> Result>, Self::Error> { + self.store.get_account_data_event(event_type).await + } + + async fn get_room_account_data_event( + &self, + room_id: &RoomId, + event_type: RoomAccountDataEventType, + ) -> Result>, Self::Error> { + self.store.get_room_account_data_event(room_id, event_type).await + } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + user_id: &UserId, + ) -> Result, Self::Error> { + self.store.get_user_room_receipt_event(room_id, receipt_type, thread, user_id).await + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: &EventId, + ) -> Result, Self::Error> { + self.store.get_event_room_receipt_events(room_id, receipt_type, thread, event_id).await + } + + async fn get_custom_value(&self, key: &[u8]) -> Result>, Self::Error> { + self.store.get_custom_value(key).await + } + + async fn set_custom_value( + &self, + key: &[u8], + value: Vec, + ) -> Result>, Self::Error> { + self.store.set_custom_value(key, value).await + } + + async fn remove_custom_value(&self, key: &[u8]) -> Result>, Self::Error> { + self.store.remove_custom_value(key).await + } + + async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> { + self.store.remove_room(room_id).await + } + + async fn save_send_queue_request( + &self, + room_id: &RoomId, + transaction_id: OwnedTransactionId, + created_at: MilliSecondsSinceUnixEpoch, + request: QueuedRequestKind, + priority: usize, + ) -> Result<(), Self::Error> { + self.store + .save_send_queue_request(room_id, transaction_id, created_at, request, priority) + .await + } + + async fn update_send_queue_request( + &self, + room_id: &RoomId, + transaction_id: &TransactionId, + content: QueuedRequestKind, + ) -> Result { + self.store.update_send_queue_request(room_id, transaction_id, content).await + } + + async fn remove_send_queue_request( + &self, + room_id: &RoomId, + transaction_id: &TransactionId, + ) -> Result { + self.store.remove_send_queue_request(room_id, transaction_id).await + } + + async fn load_send_queue_requests( + &self, + room_id: &RoomId, + ) -> Result, Self::Error> { + self.store.load_send_queue_requests(room_id).await + } + + async fn update_send_queue_request_status( + &self, + room_id: &RoomId, + transaction_id: &TransactionId, + error: Option, + ) -> Result<(), Self::Error> { + self.store.update_send_queue_request_status(room_id, transaction_id, error).await + } + + async fn load_rooms_with_unsent_requests(&self) -> Result, Self::Error> { + self.store.load_rooms_with_unsent_requests().await + } + + async fn save_dependent_queued_request( + &self, + room_id: &RoomId, + parent_txn_id: &TransactionId, + own_txn_id: ChildTransactionId, + created_at: MilliSecondsSinceUnixEpoch, + content: DependentQueuedRequestKind, + ) -> Result<(), Self::Error> { + self.store + .save_dependent_queued_request(room_id, parent_txn_id, own_txn_id, created_at, content) + .await + } + + async fn mark_dependent_queued_requests_as_ready( + &self, + room_id: &RoomId, + parent_txn_id: &TransactionId, + sent_parent_key: SentRequestKey, + ) -> Result { + self.store + .mark_dependent_queued_requests_as_ready(room_id, parent_txn_id, sent_parent_key) + .await + } + + async fn update_dependent_queued_request( + &self, + room_id: &RoomId, + own_transaction_id: &ChildTransactionId, + new_content: DependentQueuedRequestKind, + ) -> Result { + self.store.update_dependent_queued_request(room_id, own_transaction_id, new_content).await + } + + async fn remove_dependent_queued_request( + &self, + room: &RoomId, + own_txn_id: &ChildTransactionId, + ) -> Result { + self.store.remove_dependent_queued_request(room, own_txn_id).await + } + + async fn load_dependent_queued_requests( + &self, + room: &RoomId, + ) -> Result, Self::Error> { + self.store.load_dependent_queued_requests(room).await + } + + async fn upsert_thread_subscriptions( + &self, + updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>, + ) -> Result<(), Self::Error> { + self.store.upsert_thread_subscriptions(updates).await + } + + async fn remove_thread_subscription( + &self, + room: &RoomId, + thread_id: &EventId, + ) -> Result<(), Self::Error> { + self.store.remove_thread_subscription(room, thread_id).await + } + + async fn load_thread_subscription( + &self, + room: &RoomId, + thread_id: &EventId, + ) -> Result, Self::Error> { + self.store.load_thread_subscription(room, thread_id).await + } + + async fn optimize(&self) -> Result<(), Self::Error> { + self.store.optimize().await + } + + async fn get_size(&self) -> Result, Self::Error> { + self.store.get_size().await + } +} + /// Convenience functionality for state stores. #[cfg_attr(target_family = "wasm", async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait)]