mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-06-10 17:34:20 -04:00
feat(base): add state store wrapper providing synchronized calls to save_changes
Signed-off-by: Michael Goldenberg <m@mgoldenberg.net>
This commit is contained in:
committed by
Damir Jelić
parent
d0cb6648ec
commit
7eefcfb72e
@@ -46,6 +46,8 @@ use ruma::{
|
||||
serde::Raw,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
|
||||
use super::{
|
||||
ChildTransactionId, DependentQueuedRequest, DependentQueuedRequestKind, QueueWedgeError,
|
||||
@@ -851,6 +853,364 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper around a [`StateStore`] that supports synchronizing calls to
|
||||
/// [`StateStore::save_changes`].
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SaveLockedStateStore<T = Arc<DynStateStore>> {
|
||||
store: T,
|
||||
lock: Arc<Mutex<()>>,
|
||||
}
|
||||
|
||||
/// An error type that represents a scenario where a [`MutexGuard`] provided to
|
||||
/// a function does not reference the underlying [`Mutex`] in the enclosing
|
||||
/// [`SaveLockedStateStore`].
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Error)]
|
||||
#[error("a mutex guard was provided, but it does not reference the correct mutex")]
|
||||
pub struct IncorrectMutexGuardError;
|
||||
|
||||
impl From<IncorrectMutexGuardError> for StoreError {
|
||||
fn from(value: IncorrectMutexGuardError) -> Self {
|
||||
Self::backend(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> SaveLockedStateStore<T> {
|
||||
/// Creates a new [`SaveLockedStateStore`] with the provided store.
|
||||
#[allow(unused)]
|
||||
pub fn new(store: T) -> Self {
|
||||
Self { store, lock: Arc::new(Mutex::new(())) }
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying [`Mutex`] used to synchronize
|
||||
/// calls to [`StateStore::save_changes`].
|
||||
#[allow(unused)]
|
||||
pub fn lock(&self) -> &Mutex<()> {
|
||||
self.lock.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: StateStore> SaveLockedStateStore<T> {
|
||||
/// Provides a means of calling [`StateStore::save_changes`] when the caller
|
||||
/// has already acquired the underlying [`Mutex`]. Returns an error if
|
||||
/// the [`MutexGuard`] provided does not reference the underlying
|
||||
/// [`Mutex`].
|
||||
#[allow(unused)]
|
||||
pub async fn save_changes_with_guard(
|
||||
&self,
|
||||
guard: &MutexGuard<'_, ()>,
|
||||
changes: &StateChanges,
|
||||
) -> Result<(), StoreError> {
|
||||
if !std::ptr::eq(MutexGuard::mutex(guard), self.lock()) {
|
||||
Err(IncorrectMutexGuardError.into())
|
||||
} else {
|
||||
self.store.save_changes(changes).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_family = "wasm"), async_trait)]
|
||||
impl<T: StateStore> StateStore for SaveLockedStateStore<T> {
|
||||
type Error = T::Error;
|
||||
|
||||
async fn get_kv_data(
|
||||
&self,
|
||||
key: StateStoreDataKey<'_>,
|
||||
) -> Result<Option<StateStoreDataValue>, Self::Error> {
|
||||
self.store.get_kv_data(key).await
|
||||
}
|
||||
|
||||
async fn set_kv_data(
|
||||
&self,
|
||||
key: StateStoreDataKey<'_>,
|
||||
value: StateStoreDataValue,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.store.set_kv_data(key, value).await
|
||||
}
|
||||
|
||||
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
|
||||
self.store.remove_kv_data(key).await
|
||||
}
|
||||
|
||||
async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
|
||||
let _guard = self.lock.lock().await;
|
||||
self.store.save_changes(changes).await
|
||||
}
|
||||
|
||||
async fn get_presence_event(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
|
||||
self.store.get_presence_event(user_id).await
|
||||
}
|
||||
|
||||
async fn get_presence_events(
|
||||
&self,
|
||||
user_ids: &[OwnedUserId],
|
||||
) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
|
||||
self.store.get_presence_events(user_ids).await
|
||||
}
|
||||
|
||||
async fn get_state_event(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
|
||||
self.store.get_state_event(room_id, event_type, state_key).await
|
||||
}
|
||||
|
||||
async fn get_state_events(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: StateEventType,
|
||||
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
|
||||
self.store.get_state_events(room_id, event_type).await
|
||||
}
|
||||
|
||||
async fn get_state_events_for_keys(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: StateEventType,
|
||||
state_keys: &[&str],
|
||||
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
|
||||
self.store.get_state_events_for_keys(room_id, event_type, state_keys).await
|
||||
}
|
||||
|
||||
async fn get_profile(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
|
||||
self.store.get_profile(room_id, user_id).await
|
||||
}
|
||||
|
||||
async fn get_profiles<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_ids: &'a [OwnedUserId],
|
||||
) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
|
||||
self.store.get_profiles(room_id, user_ids).await
|
||||
}
|
||||
|
||||
async fn get_user_ids(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
memberships: RoomMemberships,
|
||||
) -> Result<Vec<OwnedUserId>, Self::Error> {
|
||||
self.store.get_user_ids(room_id, memberships).await
|
||||
}
|
||||
|
||||
async fn get_room_infos(
|
||||
&self,
|
||||
room_load_settings: &RoomLoadSettings,
|
||||
) -> Result<Vec<RoomInfo>, Self::Error> {
|
||||
self.store.get_room_infos(room_load_settings).await
|
||||
}
|
||||
|
||||
async fn get_users_with_display_name(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_name: &DisplayName,
|
||||
) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
|
||||
self.store.get_users_with_display_name(room_id, display_name).await
|
||||
}
|
||||
|
||||
async fn get_users_with_display_names<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
display_names: &'a [DisplayName],
|
||||
) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>, Self::Error> {
|
||||
self.store.get_users_with_display_names(room_id, display_names).await
|
||||
}
|
||||
|
||||
async fn get_account_data_event(
|
||||
&self,
|
||||
event_type: GlobalAccountDataEventType,
|
||||
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
|
||||
self.store.get_account_data_event(event_type).await
|
||||
}
|
||||
|
||||
async fn get_room_account_data_event(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
|
||||
self.store.get_room_account_data_event(room_id, event_type).await
|
||||
}
|
||||
|
||||
async fn get_user_room_receipt_event(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
receipt_type: ReceiptType,
|
||||
thread: ReceiptThread,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
|
||||
self.store.get_user_room_receipt_event(room_id, receipt_type, thread, user_id).await
|
||||
}
|
||||
|
||||
async fn get_event_room_receipt_events(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
receipt_type: ReceiptType,
|
||||
thread: ReceiptThread,
|
||||
event_id: &EventId,
|
||||
) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
|
||||
self.store.get_event_room_receipt_events(room_id, receipt_type, thread, event_id).await
|
||||
}
|
||||
|
||||
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
|
||||
self.store.get_custom_value(key).await
|
||||
}
|
||||
|
||||
async fn set_custom_value(
|
||||
&self,
|
||||
key: &[u8],
|
||||
value: Vec<u8>,
|
||||
) -> Result<Option<Vec<u8>>, Self::Error> {
|
||||
self.store.set_custom_value(key, value).await
|
||||
}
|
||||
|
||||
async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
|
||||
self.store.remove_custom_value(key).await
|
||||
}
|
||||
|
||||
async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
|
||||
self.store.remove_room(room_id).await
|
||||
}
|
||||
|
||||
async fn save_send_queue_request(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
transaction_id: OwnedTransactionId,
|
||||
created_at: MilliSecondsSinceUnixEpoch,
|
||||
request: QueuedRequestKind,
|
||||
priority: usize,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.store
|
||||
.save_send_queue_request(room_id, transaction_id, created_at, request, priority)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn update_send_queue_request(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
transaction_id: &TransactionId,
|
||||
content: QueuedRequestKind,
|
||||
) -> Result<bool, Self::Error> {
|
||||
self.store.update_send_queue_request(room_id, transaction_id, content).await
|
||||
}
|
||||
|
||||
async fn remove_send_queue_request(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
transaction_id: &TransactionId,
|
||||
) -> Result<bool, Self::Error> {
|
||||
self.store.remove_send_queue_request(room_id, transaction_id).await
|
||||
}
|
||||
|
||||
async fn load_send_queue_requests(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Vec<QueuedRequest>, Self::Error> {
|
||||
self.store.load_send_queue_requests(room_id).await
|
||||
}
|
||||
|
||||
async fn update_send_queue_request_status(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
transaction_id: &TransactionId,
|
||||
error: Option<QueueWedgeError>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.store.update_send_queue_request_status(room_id, transaction_id, error).await
|
||||
}
|
||||
|
||||
async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
|
||||
self.store.load_rooms_with_unsent_requests().await
|
||||
}
|
||||
|
||||
async fn save_dependent_queued_request(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
parent_txn_id: &TransactionId,
|
||||
own_txn_id: ChildTransactionId,
|
||||
created_at: MilliSecondsSinceUnixEpoch,
|
||||
content: DependentQueuedRequestKind,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.store
|
||||
.save_dependent_queued_request(room_id, parent_txn_id, own_txn_id, created_at, content)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn mark_dependent_queued_requests_as_ready(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
parent_txn_id: &TransactionId,
|
||||
sent_parent_key: SentRequestKey,
|
||||
) -> Result<usize, Self::Error> {
|
||||
self.store
|
||||
.mark_dependent_queued_requests_as_ready(room_id, parent_txn_id, sent_parent_key)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn update_dependent_queued_request(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
own_transaction_id: &ChildTransactionId,
|
||||
new_content: DependentQueuedRequestKind,
|
||||
) -> Result<bool, Self::Error> {
|
||||
self.store.update_dependent_queued_request(room_id, own_transaction_id, new_content).await
|
||||
}
|
||||
|
||||
async fn remove_dependent_queued_request(
|
||||
&self,
|
||||
room: &RoomId,
|
||||
own_txn_id: &ChildTransactionId,
|
||||
) -> Result<bool, Self::Error> {
|
||||
self.store.remove_dependent_queued_request(room, own_txn_id).await
|
||||
}
|
||||
|
||||
async fn load_dependent_queued_requests(
|
||||
&self,
|
||||
room: &RoomId,
|
||||
) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
|
||||
self.store.load_dependent_queued_requests(room).await
|
||||
}
|
||||
|
||||
async fn upsert_thread_subscriptions(
|
||||
&self,
|
||||
updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.store.upsert_thread_subscriptions(updates).await
|
||||
}
|
||||
|
||||
async fn remove_thread_subscription(
|
||||
&self,
|
||||
room: &RoomId,
|
||||
thread_id: &EventId,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.store.remove_thread_subscription(room, thread_id).await
|
||||
}
|
||||
|
||||
async fn load_thread_subscription(
|
||||
&self,
|
||||
room: &RoomId,
|
||||
thread_id: &EventId,
|
||||
) -> Result<Option<StoredThreadSubscription>, Self::Error> {
|
||||
self.store.load_thread_subscription(room, thread_id).await
|
||||
}
|
||||
|
||||
async fn optimize(&self) -> Result<(), Self::Error> {
|
||||
self.store.optimize().await
|
||||
}
|
||||
|
||||
async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
|
||||
self.store.get_size().await
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience functionality for state stores.
|
||||
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
|
||||
#[cfg_attr(not(target_family = "wasm"), async_trait)]
|
||||
|
||||
Reference in New Issue
Block a user