send queue: make SendHandle::abort/update more precise

Using `SendHandle::abort()` after the event has been sent would look
like a successful abort of the event, while it's not the case; this
fixes this by having the state store backends return whether they've
touched an entry in the database.
This commit is contained in:
Benjamin Bouvier
2024-07-01 15:06:22 +02:00
parent 16aa6df0b8
commit 2f125e97ee
6 changed files with 106 additions and 36 deletions

View File

@@ -890,7 +890,7 @@ impl StateStore for MemoryStore {
room_id: &RoomId,
transaction_id: &TransactionId,
content: SerializableEventContent,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
if let Some(entry) = self
.send_queue_events
.write()
@@ -902,15 +902,17 @@ impl StateStore for MemoryStore {
{
entry.event = content;
entry.is_wedged = false;
Ok(true)
} else {
Ok(false)
}
Ok(())
}
async fn remove_send_queue_event(
&self,
room_id: &RoomId,
transaction_id: &TransactionId,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
let mut q = self.send_queue_events.write().unwrap();
let entry = q.get_mut(room_id);
@@ -922,10 +924,11 @@ impl StateStore for MemoryStore {
if entry.is_empty() {
q.remove(room_id);
}
return Ok(true);
}
}
Ok(())
Ok(false)
}
async fn load_send_queue_events(

View File

@@ -418,20 +418,24 @@ pub trait StateStore: AsyncTraitDeps {
/// * `transaction_id` - The unique key identifying the event to be sent
/// (and its transaction).
/// * `content` - Serializable event content to replace the original one.
///
/// Returns true if an event has been updated, or false otherwise.
async fn update_send_queue_event(
&self,
room_id: &RoomId,
transaction_id: &TransactionId,
content: SerializableEventContent,
) -> Result<(), Self::Error>;
) -> Result<bool, Self::Error>;
/// Remove an event previously inserted with [`Self::save_send_queue_event`]
/// from the database, based on its transaction id.
///
/// Returns true if an event has been removed, or false otherwise.
async fn remove_send_queue_event(
&self,
room_id: &RoomId,
transaction_id: &TransactionId,
) -> Result<(), Self::Error>;
) -> Result<bool, Self::Error>;
/// Loads all the send queue events for the given room.
async fn load_send_queue_events(
@@ -687,7 +691,7 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
room_id: &RoomId,
transaction_id: &TransactionId,
content: SerializableEventContent,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
self.0.update_send_queue_event(room_id, transaction_id, content).await.map_err(Into::into)
}
@@ -695,7 +699,7 @@ impl<T: StateStore> StateStore for EraseStateStoreError<T> {
&self,
room_id: &RoomId,
transaction_id: &TransactionId,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
self.0.remove_send_queue_event(room_id, transaction_id).await.map_err(Into::into)
}

View File

@@ -1400,7 +1400,7 @@ impl_state_store!({
room_id: &RoomId,
transaction_id: &TransactionId,
content: SerializableEventContent,
) -> Result<()> {
) -> Result<bool> {
let encoded_key = self.encode_key(keys::ROOM_SEND_QUEUE, room_id);
let tx = self
@@ -1423,21 +1423,22 @@ impl_state_store!({
if let Some(entry) = prev.iter_mut().find(|entry| entry.transaction_id == transaction_id) {
entry.event = content;
entry.is_wedged = false;
// Save the new vector into db.
obj.put_key_val(&encoded_key, &self.serialize_value(&prev)?)?;
tx.await.into_result()?;
Ok(true)
} else {
Ok(false)
}
// Save the new vector into db.
obj.put_key_val(&encoded_key, &self.serialize_value(&prev)?)?;
tx.await.into_result()?;
Ok(())
}
async fn remove_send_queue_event(
&self,
room_id: &RoomId,
transaction_id: &TransactionId,
) -> Result<()> {
) -> Result<bool> {
let encoded_key = self.encode_key(keys::ROOM_SEND_QUEUE, room_id);
let tx = self
@@ -1459,12 +1460,13 @@ impl_state_store!({
} else {
obj.put_key_val(&encoded_key, &self.serialize_value(&prev)?)?;
}
tx.await.into_result()?;
return Ok(true);
}
}
tx.await.into_result()?;
Ok(())
Ok(false)
}
async fn load_send_queue_events(&self, room_id: &RoomId) -> Result<Vec<QueuedEvent>> {

View File

@@ -1709,7 +1709,7 @@ impl StateStore for SqliteStateStore {
room_id: &RoomId,
transaction_id: &TransactionId,
content: SerializableEventContent,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
let room_id = self.encode_key(keys::SEND_QUEUE, room_id);
let content = self.serialize_json(&content)?;
@@ -1717,35 +1717,38 @@ impl StateStore for SqliteStateStore {
// transaction id is neither encrypted or hashed.
let transaction_id = transaction_id.to_string();
self.acquire()
let num_updated = self.acquire()
.await?
.with_transaction(move |txn| {
txn.prepare_cached("UPDATE send_queue_events SET wedged = false, content = ? WHERE room_id = ? AND transaction_id = ?")?.execute((content, room_id, transaction_id))?;
Ok(())
txn.prepare_cached("UPDATE send_queue_events SET wedged = false, content = ? WHERE room_id = ? AND transaction_id = ?")?.execute((content, room_id, transaction_id))
})
.await
.await?;
Ok(num_updated > 0)
}
async fn remove_send_queue_event(
&self,
room_id: &RoomId,
transaction_id: &TransactionId,
) -> Result<(), Self::Error> {
) -> Result<bool, Self::Error> {
let room_id = self.encode_key(keys::SEND_QUEUE, room_id);
// See comment in `save_send_queue_event`.
let transaction_id = transaction_id.to_string();
self.acquire()
let num_deleted = self
.acquire()
.await?
.with_transaction(move |txn| {
txn.prepare_cached(
"DELETE FROM send_queue_events WHERE room_id = ? AND transaction_id = ?",
)?
.execute((room_id, transaction_id))?;
Ok(())
.execute((room_id, transaction_id))
})
.await
.await?;
Ok(num_deleted > 0)
}
async fn load_send_queue_events(

View File

@@ -674,13 +674,19 @@ impl QueueStorage {
) -> Result<(), RoomSendQueueStorageError> {
self.mark_as_not_being_sent(transaction_id).await;
Ok(self
let removed = self
.client
.get()
.ok_or(RoomSendQueueStorageError::ClientShuttingDown)?
.store()
.remove_send_queue_event(&self.room_id, transaction_id)
.await?)
.await?;
if !removed {
warn!(txn_id = %transaction_id, "event marked as sent was missing from storage");
}
Ok(())
}
/// Cancel a sending command for an event that has been sent with
@@ -699,14 +705,15 @@ impl QueueStorage {
return Ok(false);
}
self.client
let removed = self
.client
.get()
.ok_or(RoomSendQueueStorageError::ClientShuttingDown)?
.store()
.remove_send_queue_event(&self.room_id, transaction_id)
.await?;
Ok(true)
Ok(removed)
}
/// Replace an event that has been sent with
@@ -727,14 +734,15 @@ impl QueueStorage {
return Ok(false);
}
self.client
let edited = self
.client
.get()
.ok_or(RoomSendQueueStorageError::ClientShuttingDown)?
.store()
.update_send_queue_event(&self.room_id, transaction_id, serializable)
.await?;
Ok(true)
Ok(edited)
}
/// Returns a list of the local echoes, that is, all the events that we're

View File

@@ -1,4 +1,5 @@
use std::{
ops::Not as _,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex as StdMutex,
@@ -1018,6 +1019,55 @@ async fn test_abort_after_disable() {
assert!(errors.is_empty());
}
#[async_test]
async fn test_abort_or_edit_after_send() {
let (client, server) = logged_in_client_with_server().await;
// Mark the room as joined.
let room_id = room_id!("!a:b.c");
let room = mock_sync_with_new_room(
|builder| {
builder.add_joined_room(JoinedRoomBuilder::new(room_id));
},
&client,
&server,
room_id,
)
.await;
// Start with an enabled sending queue.
client.send_queue().set_enabled(true).await;
let q = room.send_queue();
let (local_echoes, mut watch) = q.subscribe().await.unwrap();
assert!(local_echoes.is_empty());
assert!(watch.is_empty());
server.reset().await;
mock_encryption_state(&server, false).await;
mock_send_event(event_id!("$1")).mount(&server).await;
let handle = q.send(RoomMessageEventContent::text_plain("hey there").into()).await.unwrap();
// It is first seen as a local echo,
let (txn, _) = assert_update!(watch => local echo { body = "hey there" });
// Then sent.
assert_update!(watch => sent { txn = txn, });
// Editing shouldn't work anymore.
assert!(handle
.edit(RoomMessageEventContent::text_plain("i meant something completely different").into())
.await
.unwrap()
.not());
// Neither will aborting.
assert!(handle.abort().await.unwrap().not());
assert!(watch.is_empty());
}
#[async_test]
async fn test_unrecoverable_errors() {
let (client, server) = logged_in_client_with_server().await;