From a88d6b37dcbec9055bc66fcff089c768c764b1b3 Mon Sep 17 00:00:00 2001 From: Benjamin Bouvier Date: Thu, 14 Aug 2025 13:19:56 +0200 Subject: [PATCH] feat(event cache): also subscribe to a thread if we've sent a message into it --- crates/matrix-sdk/src/event_cache/mod.rs | 155 ++++++++++++++++-- .../tests/integration/send_queue.rs | 114 ++++++++++++- 2 files changed, 258 insertions(+), 11 deletions(-) diff --git a/crates/matrix-sdk/src/event_cache/mod.rs b/crates/matrix-sdk/src/event_cache/mod.rs index 7f2cf758f..123bcad64 100644 --- a/crates/matrix-sdk/src/event_cache/mod.rs +++ b/crates/matrix-sdk/src/event_cache/mod.rs @@ -28,7 +28,7 @@ #![forbid(missing_docs)] use std::{ - collections::BTreeMap, + collections::{BTreeMap, HashMap}, fmt, sync::{Arc, OnceLock}, }; @@ -44,6 +44,7 @@ use matrix_sdk_base::{ }, executor::AbortOnDrop, linked_chunk::{self, lazy_loader::LazyLoaderError, OwnedLinkedChunkId}, + store::SerializableEventContent, store_locks::LockStoreError, sync::RoomUpdates, timer, @@ -52,7 +53,11 @@ use matrix_sdk_common::executor::{spawn, JoinHandle}; #[cfg(feature = "experimental-search")] use matrix_sdk_search::error::IndexError; use room::RoomEventCacheState; -use ruma::{events::AnySyncEphemeralRoomEvent, serde::Raw, OwnedEventId, OwnedRoomId, RoomId}; +use ruma::{ + events::{room::encrypted, AnySyncEphemeralRoomEvent}, + serde::Raw, + OwnedEventId, OwnedRoomId, OwnedTransactionId, RoomId, +}; use tokio::{ select, sync::{ @@ -62,7 +67,11 @@ use tokio::{ }; use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument as _, Span}; -use crate::{client::WeakClient, Client}; +use crate::{ + client::WeakClient, + send_queue::{LocalEchoContent, RoomSendQueueUpdate, SendQueueUpdate}, + Client, +}; mod deduplicator; mod pagination; @@ -425,6 +434,7 @@ impl EventCache { self.inner.generic_update_sender.subscribe() } + #[instrument(skip(client, thread_subscriber_sender))] async fn handle_thread_subscriber_linked_chunk_update( client: &WeakClient, thread_subscriber_sender: &Sender<()>, @@ -508,7 +518,102 @@ impl EventCache { } } - return true; + true + } + + #[instrument(skip(client, thread_subscriber_sender))] + async fn handle_thread_subscriber_send_queue_update( + client: &WeakClient, + thread_subscriber_sender: &Sender<()>, + events_being_sent: &mut HashMap, + up: SendQueueUpdate, + ) -> bool { + let Some(client) = client.get() else { + // Client shutting down. + debug!("Client is shutting down, exiting thread subscriber task"); + return false; + }; + + let room_id = up.room_id; + let Some(room) = client.get_room(&room_id) else { + warn!(%room_id, "unknown room"); + return true; + }; + + let extract_thread_root = |serialized_event: SerializableEventContent| { + match serialized_event.deserialize() { + Ok(content) => { + if let Some(encrypted::Relation::Thread(thread)) = content.relation() { + return Some(thread.event_id); + } + } + Err(err) => { + warn!("error when deserializing content of a local echo: {err}"); + } + } + None + }; + + let (thread_root, subscribe_up_to) = match up.update { + RoomSendQueueUpdate::NewLocalEvent(local_echo) => { + match local_echo.content { + LocalEchoContent::Event { serialized_event, .. } => { + if let Some(thread_root) = extract_thread_root(serialized_event) { + events_being_sent.insert(local_echo.transaction_id, thread_root); + } + } + LocalEchoContent::React { .. } => { + // Nothing to do, reactions don't count as a thread + // subscription. + } + } + return true; + } + + RoomSendQueueUpdate::CancelledLocalEvent { transaction_id } => { + events_being_sent.remove(&transaction_id); + return true; + } + + RoomSendQueueUpdate::ReplacedLocalEvent { transaction_id, new_content } => { + if let Some(thread_root) = extract_thread_root(new_content) { + events_being_sent.insert(transaction_id, thread_root); + } else { + // It could be that the event isn't part of a thread anymore; handle that by + // removing the pending transaction id. + events_being_sent.remove(&transaction_id); + } + return true; + } + + RoomSendQueueUpdate::SentEvent { transaction_id, event_id } => { + if let Some(thread_root) = events_being_sent.remove(&transaction_id) { + (thread_root, event_id) + } else { + // We don't know about the event that has been sent, so ignore it. + trace!(%transaction_id, "received a sent event that we didn't know about, ignoring"); + return true; + } + } + + RoomSendQueueUpdate::SendError { .. } + | RoomSendQueueUpdate::RetryEvent { .. } + | RoomSendQueueUpdate::MediaUpload { .. } => { + // Nothing to do for these bad boys. + return true; + } + }; + + // And if we've found such a mention, subscribe to the thread up to this event. + trace!(thread = %thread_root, up_to = %subscribe_up_to, "found a new thread to subscribe to"); + if let Err(err) = room.subscribe_thread_if_needed(&thread_root, Some(subscribe_up_to)).await + { + warn!(%err, "Failed to subscribe to thread"); + } else { + let _ = thread_subscriber_sender.send(()); + } + + true } #[instrument(skip_all)] @@ -517,16 +622,46 @@ impl EventCache { linked_chunk_update_sender: Sender, thread_subscriber_sender: Sender<()>, ) { - if client.get().map_or(false, |client| !client.enabled_thread_subscriptions()) { - trace!("Not spawning the thread subscriber task, because the client is shutting down or is not interested in those"); - return; - } + let mut send_q_rx = if let Some(client) = client.get() { + if !client.enabled_thread_subscriptions() { + trace!("Thread subscriptions are not enabled, not spawning thread subscriber task"); + return; + } - let mut rx = linked_chunk_update_sender.subscribe(); + client.send_queue().subscribe() + } else { + trace!("Client is shutting down, not spawning thread subscriber task"); + return; + }; + + let mut linked_chunk_rx = linked_chunk_update_sender.subscribe(); + + // A mapping of local echoes (events being sent), to their thread root, if + // they're in an in-thread reply. + // + // Entirely managed by `handle_thread_subscriber_send_queue_update`. + let mut events_being_sent = HashMap::new(); loop { select! { - res = rx.recv() => { + res = send_q_rx.recv() => { + match res { + Ok(up) => { + if !Self::handle_thread_subscriber_send_queue_update(&client, &thread_subscriber_sender, &mut events_being_sent, up).await { + break; + } + } + Err(RecvError::Closed) => { + debug!("Linked chunk update channel has been closed, exiting thread subscriber task"); + break; + } + Err(RecvError::Lagged(num_skipped)) => { + warn!(num_skipped, "Lagged behind linked chunk updates"); + } + } + } + + res = linked_chunk_rx.recv() => { match res { Ok(up) => { if !Self::handle_thread_subscriber_linked_chunk_update(&client, &thread_subscriber_sender, up).await { diff --git a/crates/matrix-sdk/tests/integration/send_queue.rs b/crates/matrix-sdk/tests/integration/send_queue.rs index db6578005..f57628f08 100644 --- a/crates/matrix-sdk/tests/integration/send_queue.rs +++ b/crates/matrix-sdk/tests/integration/send_queue.rs @@ -14,7 +14,7 @@ use matrix_sdk::{ RoomSendQueueStorageError, RoomSendQueueUpdate, SendHandle, SendQueueUpdate, }, test_utils::mocks::{MatrixMock, MatrixMockServer}, - Client, MemoryStore, + Client, MemoryStore, ThreadingSupport, }; use matrix_sdk_test::{ async_test, event_factory::EventFactory, InvitedRoomBuilder, KnockedRoomBuilder, @@ -29,6 +29,7 @@ use ruma::{ NewUnstablePollStartEventContent, UnstablePollAnswer, UnstablePollAnswers, UnstablePollStartContentBlock, UnstablePollStartEventContent, }, + relation::Thread, room::{ message::{ ImageMessageEventContent, MessageType, Relation, ReplyWithinThread, @@ -3712,3 +3713,114 @@ async fn test_update_caption_while_sending_media_event() { // That's all, folks! assert!(watch.is_empty()); } + +#[async_test] +async fn test_sending_reply_in_thread_auto_subscribe() { + let server = MatrixMockServer::new().await; + + // Assuming a client that's interested in thread subscriptions, + let client = server + .client_builder() + .on_builder(|builder| { + builder.with_threading_support(ThreadingSupport::Enabled { with_subscriptions: true }) + }) + .build() + .await; + + client.event_cache().subscribe().unwrap(); + + let room_id = room_id!("!a:b.c"); + let room = server.sync_joined_room(&client, room_id).await; + + server.mock_room_state_encryption().plain().mount().await; + + // When I send a message to a thread, + let thread_root = event_id!("$thread"); + + let mut content = RoomMessageEventContent::text_plain("hello world"); + content.relates_to = + Some(Relation::Thread(Thread::plain(thread_root.to_owned(), thread_root.to_owned()))); + + server.mock_room_send().ok(event_id!("$reply")).mock_once().mount().await; + + server + .mock_put_thread_subscription() + .match_room_id(room_id.to_owned()) + .match_thread_id(thread_root.to_owned()) + .ok() + .mock_once() + .mount() + .await; + + let (_, mut stream) = room.send_queue().subscribe().await.unwrap(); + room.send_queue().send(content.into()).await.unwrap(); + + // Let the send queue process the event. + assert_let_timeout!(Ok(RoomSendQueueUpdate::NewLocalEvent(..)) = stream.recv()); + assert_let_timeout!(Ok(RoomSendQueueUpdate::SentEvent { .. }) = stream.recv()); + assert_let_timeout!(Ok(()) = thread_subscriber_updates.recv()); + + // Check the endpoints have been correctly called. + server.server().reset().await; + + // Now, if I send a message in a thread I've already subscribed to, in automatic + // mode, this promotes the subscription to manual. + + // Subscribed, automatically. + server + .mock_get_thread_subscription() + .match_room_id(room_id.to_owned()) + .match_thread_id(thread_root.to_owned()) + .ok(true) + .mount() + .await; + + // I'll get one subscription. + server + .mock_put_thread_subscription() + .match_room_id(room_id.to_owned()) + .match_thread_id(thread_root.to_owned()) + .ok() + .mock_once() + .mount() + .await; + + server.mock_room_send().ok(event_id!("$reply")).mock_once().mount().await; + + let mut content = RoomMessageEventContent::text_plain("hello world"); + content.relates_to = + Some(Relation::Thread(Thread::plain(thread_root.to_owned(), thread_root.to_owned()))); + room.send_queue().send(content.into()).await.unwrap(); + + // Let the send queue process the event. + assert_let!(RoomSendQueueUpdate::NewLocalEvent(..) = stream.recv().await.unwrap()); + assert_let!(RoomSendQueueUpdate::SentEvent { .. } = stream.recv().await.unwrap()); + + // Check the endpoints have been correctly called. + server.server().reset().await; + + // Subscribed, but manually. + server + .mock_get_thread_subscription() + .match_room_id(room_id.to_owned()) + .match_thread_id(thread_root.to_owned()) + .ok(false) + .mount() + .await; + + // I'll get zero subscription. + server.mock_put_thread_subscription().ok().expect(0).mount().await; + + server.mock_room_send().ok(event_id!("$reply")).mock_once().mount().await; + + let mut content = RoomMessageEventContent::text_plain("hello world"); + content.relates_to = + Some(Relation::Thread(Thread::plain(thread_root.to_owned(), thread_root.to_owned()))); + room.send_queue().send(content.into()).await.unwrap(); + + // Let the send queue process the event. + assert_let!(RoomSendQueueUpdate::NewLocalEvent(..) = stream.recv().await.unwrap()); + assert_let!(RoomSendQueueUpdate::SentEvent { .. } = stream.recv().await.unwrap()); + + sleep(Duration::from_millis(100)).await; +}