From 071f982eed94465f7058bbb1afed973cf669c847 Mon Sep 17 00:00:00 2001 From: Benjamin Bouvier Date: Thu, 5 Feb 2026 16:11:19 +0100 Subject: [PATCH] fix(event cache): don't include thread responses in the pinned event cache There were two issues: - first, `load_or_fetch_event_with_relations()` allowed to pass a filter, but the filter wasn't taken into account when fetching relations from the network. This would cause the initial load of pinned events to also include thread responses, which we don't want. - similarly, when adding related events from sync, we'd only look if an event had a `m.relates_to` field; but it could be a thread response being added in live. The two issues are fixed similarly, by using a new `extract_relation` serde helper that gives both the related_to event and the relation type. That way, we can apply a manual filter in `load_or_fetch_event_with_relations` after fetching relations from network, and we can filter out live events based on the relation type. --- crates/matrix-sdk-common/src/serde_helpers.rs | 45 +++++++-- .../src/event_cache/room/pinned_events.rs | 27 ++---- crates/matrix-sdk/src/room/mod.rs | 47 ++++++++- .../tests/integration/room/pinned_events.rs | 96 ++++++++++++++++++- 4 files changed, 186 insertions(+), 29 deletions(-) diff --git a/crates/matrix-sdk-common/src/serde_helpers.rs b/crates/matrix-sdk-common/src/serde_helpers.rs index 7d3416804..f6e37fc12 100644 --- a/crates/matrix-sdk-common/src/serde_helpers.rs +++ b/crates/matrix-sdk-common/src/serde_helpers.rs @@ -27,12 +27,24 @@ use serde::Deserialize; use crate::deserialized_responses::{ThreadSummary, ThreadSummaryStatus}; -#[derive(Deserialize)] -enum RelationsType { +/// The type of relation an event has to another one, if any. +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq)] +pub enum RelationsType { + /// The event is part of a thread, and the related event is the thread root. #[serde(rename = "m.thread")] Thread, + /// The event is an edit of another event, and the related event is the one + /// being edited. #[serde(rename = "m.replace")] Edit, + /// The event is an annotation of (reaction to) another event, and the + /// related event is the one being annotated. + #[serde(rename = "m.annotation")] + Annotation, + /// The event is referencing another event, and the related event is the one + /// being referenced. + #[serde(rename = "m.reference")] + Reference, } #[derive(Deserialize)] @@ -63,7 +75,7 @@ pub fn extract_thread_root_from_content( let relates_to = content.deserialize_as_unchecked::().ok()?.relates_to?; match relates_to.rel_type { RelationsType::Thread => relates_to.event_id, - RelationsType::Edit => None, + RelationsType::Edit | RelationsType::Reference | RelationsType::Annotation => None, } } @@ -91,10 +103,19 @@ pub fn extract_edit_target(event: &Raw) -> Option("content").ok().flatten()?.relates_to?; match relates_to.rel_type { RelationsType::Edit => relates_to.event_id, - RelationsType::Thread => None, + RelationsType::Thread | RelationsType::Reference | RelationsType::Annotation => None, } } +/// Try to extract the type and target of a relation, from a raw timeline event, +/// if provided. +pub fn extract_relation( + event: &Raw, +) -> Option<(RelationsType, OwnedEventId)> { + let relates_to = event.get_field::("content").ok().flatten()?.relates_to?; + Some((relates_to.rel_type, relates_to.event_id?)) +} + #[allow(missing_debug_implementations)] #[derive(Deserialize)] struct Relations { @@ -154,14 +175,17 @@ pub fn extract_timestamp( #[cfg(test)] mod tests { use assert_matches::assert_matches; - use ruma::{UInt, event_id}; + use ruma::{UInt, event_id, owned_event_id}; use serde_json::json; use super::{ MilliSecondsSinceUnixEpoch, Raw, extract_bundled_thread_summary, extract_thread_root, extract_timestamp, }; - use crate::deserialized_responses::{ThreadSummary, ThreadSummaryStatus}; + use crate::{ + deserialized_responses::{ThreadSummary, ThreadSummaryStatus}, + serde_helpers::{RelationsType, extract_relation}, + }; #[test] fn test_extract_thread_root() { @@ -188,6 +212,8 @@ mod tests { let observed_thread_root = extract_thread_root(&event); assert_eq!(observed_thread_root.as_deref(), Some(thread_root)); + let observed_relation = extract_relation(&event).unwrap(); + assert_eq!(observed_relation, (RelationsType::Thread, thread_root.to_owned())); // If the event doesn't have a content for some reason (redacted), it returns // None. @@ -202,6 +228,7 @@ mod tests { let observed_thread_root = extract_thread_root(&event); assert_matches!(observed_thread_root, None); + assert_matches!(extract_relation(&event), None); // If the event has a content but with no `m.relates_to` field, it returns None. let event = Raw::new(&json!({ @@ -218,6 +245,7 @@ mod tests { let observed_thread_root = extract_thread_root(&event); assert_matches!(observed_thread_root, None); + assert_matches!(extract_relation(&event), None); // If the event has a relation, but it's not a thread reply, it returns None. let event = Raw::new(&json!({ @@ -238,6 +266,11 @@ mod tests { let observed_thread_root = extract_thread_root(&event); assert_matches!(observed_thread_root, None); + let observed_relation = extract_relation(&event).unwrap(); + assert_eq!( + observed_relation, + (RelationsType::Reference, owned_event_id!("$referenced_event_id:example.com")) + ); } #[test] diff --git a/crates/matrix-sdk/src/event_cache/room/pinned_events.rs b/crates/matrix-sdk/src/event_cache/room/pinned_events.rs index d4bdc107d..295fa2111 100644 --- a/crates/matrix-sdk/src/event_cache/room/pinned_events.rs +++ b/crates/matrix-sdk/src/event_cache/room/pinned_events.rs @@ -24,6 +24,7 @@ use matrix_sdk_base::{ store::{EventCacheStoreLock, EventCacheStoreLockGuard, EventCacheStoreLockState}, }, linked_chunk::{LinkedChunkId, OwnedLinkedChunkId, Update}, + serde_helpers::{RelationsType, extract_relation}, task_monitor::BackgroundTaskHandle, }; #[cfg(feature = "e2e-encryption")] @@ -36,7 +37,6 @@ use ruma::{ room_version_rules::RedactionRules, serde::Raw, }; -use serde::Deserialize; use tokio::sync::{ Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard, broadcast::{Receiver, Sender}, @@ -553,24 +553,15 @@ impl PinnedEventCache { /// Given a raw event, try to extract the target event ID of a relation as /// defined with `m.relates_to`. fn extract_relation_target(raw: &Raw) -> Option { - #[derive(Deserialize)] - struct SimplifiedRelation { - event_id: OwnedEventId, + let (rel_type, event_id) = extract_relation(raw)?; + + // Don't include thread responses in the pinned event chunk. + match rel_type { + RelationsType::Thread => None, + RelationsType::Edit | RelationsType::Annotation | RelationsType::Reference => { + Some(event_id) + } } - - #[derive(Deserialize)] - struct SimplifiedContent { - #[serde(rename = "m.relates_to")] - relates_to: Option, - } - - let SimplifiedContent { relates_to: Some(relation) } = - raw.get_field::("content").ok()?? - else { - return None; - }; - - Some(relation.event_id) } /// Given a raw event, try to extract the target event ID of a live diff --git a/crates/matrix-sdk/src/room/mod.rs b/crates/matrix-sdk/src/room/mod.rs index f1a4471e1..491ca25db 100644 --- a/crates/matrix-sdk/src/room/mod.rs +++ b/crates/matrix-sdk/src/room/mod.rs @@ -46,6 +46,7 @@ use matrix_sdk_base::{ RawAnySyncOrStrippedState, RawSyncOrStrippedState, SyncOrStrippedState, }, media::{MediaThumbnailSettings, store::IgnoreMediaRetentionPolicy}, + serde_helpers::{RelationsType, extract_relation}, store::{StateStoreExt, ThreadSubscriptionStatus}, }; #[cfg(feature = "e2e-encryption")] @@ -854,8 +855,22 @@ impl Room { request_config: Option, ) -> Result<(TimelineEvent, Vec)> { let fetch_relations = async || { + // If there's only a single filter, we can use a more efficient request, + // specialized on the filter type. + // + // Otherwise, we need to get all the relations: + // - either because no filters implies we fetch all relations, + // - or because there are multiple filters and we must filter out manually. + let include_relations = if let Some(filter) = &filter + && filter.len() == 1 + { + IncludeRelations::RelationsOfType(filter[0].clone()) + } else { + IncludeRelations::AllRelations + }; + let mut opts = RelationsOptions { - include_relations: IncludeRelations::AllRelations, + include_relations, recurse: true, limit: Some(uint!(256)), ..Default::default() @@ -865,7 +880,32 @@ impl Room { loop { match self.relations(event_id.to_owned(), opts.clone()).await { Ok(relations) => { - events.extend(relations.chunk); + if let Some(filter) = filter.as_ref() { + // Manually filter out the relation types we're interested in. + events.extend(relations.chunk.into_iter().filter_map(|ev| { + let (rel_type, _) = extract_relation(ev.raw())?; + filter + .iter() + .any(|ruma_filter| match ruma_filter { + RelationType::Annotation => { + rel_type == RelationsType::Annotation + } + RelationType::Replacement => { + rel_type == RelationsType::Edit + } + RelationType::Thread => rel_type == RelationsType::Thread, + RelationType::Reference => { + rel_type == RelationsType::Reference + } + _ => false, + }) + .then_some(ev) + })); + } else { + // No filter: include all events from the response. + events.extend(relations.chunk); + } + if let Some(next_from) = relations.next_batch_token { opts.from = Some(next_from); } else { @@ -914,7 +954,8 @@ impl Room { // Try to get the relations from the event cache (if we have one). if let Some((event_cache, _drop_handles)) = event_cache - && let Some(relations) = event_cache.find_event_relations(event_id, filter).await.ok() + && let Some(relations) = + event_cache.find_event_relations(event_id, filter.clone()).await.ok() && !relations.is_empty() { return Ok((event, relations)); diff --git a/crates/matrix-sdk/tests/integration/room/pinned_events.rs b/crates/matrix-sdk/tests/integration/room/pinned_events.rs index 16b9bcfeb..1f84df56a 100644 --- a/crates/matrix-sdk/tests/integration/room/pinned_events.rs +++ b/crates/matrix-sdk/tests/integration/room/pinned_events.rs @@ -1,8 +1,12 @@ use std::{ops::Not as _, sync::Arc}; use matrix_sdk::{ - Room, event_cache::RoomEventCacheUpdate, store::StoreConfig, - test_utils::mocks::MatrixMockServer, timeout::timeout, + Room, + event_cache::RoomEventCacheUpdate, + room::IncludeRelations, + store::StoreConfig, + test_utils::mocks::{MatrixMockServer, RoomRelationsResponseTemplate}, + timeout::timeout, }; use matrix_sdk_base::event_cache::store::MemoryStore; use matrix_sdk_test::{JoinedRoomBuilder, StateTestEvent, async_test, event_factory::EventFactory}; @@ -267,3 +271,91 @@ async fn test_pinned_events_are_reloaded_from_storage() { "The pinned event should have been reloaded from storage" ); } + +#[async_test] +async fn test_pinned_events_dont_include_thread_responses() { + let room_id = room_id!("!galette:saucisse.bzh"); + let pinned_event_id = event_id!("$pinned_event"); + + let f = EventFactory::new().room(room_id).sender(user_id!("@alice:example.org")); + + // Create the pinned event, that's a thread root. + let pinned_event = f.text_msg("I'm pinned!").event_id(pinned_event_id).into_event(); + + let thread_response = f + .text_msg("I'm a thread response!") + .in_thread(pinned_event_id, pinned_event_id) + .into_event(); + + // Create a first client. + let server = MatrixMockServer::new().await; + + server.mock_room_event().match_event_id().ok(pinned_event.clone()).mock_once().mount().await; + + // Serve a thread relation over network; it should NOT be included in the pinned + // event cache for that room. + server + .mock_room_relations() + .match_subrequest(IncludeRelations::AllRelations) + .match_target_event(pinned_event_id.to_owned()) + .ok(RoomRelationsResponseTemplate::default() + .events(vec![thread_response.raw().clone().cast_unchecked()])) + .mock_once() + .mount() + .await; + + let client = server.client_builder().build().await; + + // Subscribe the event cache to sync updates. + client.event_cache().subscribe().unwrap(); + + // Sync the room with the pinned event ID in the room state. + // + // This is important: the pinned events list must include our event ID, + // otherwise the initial reload from network will clear the storage-loaded + // events. + let pinned_events_state = StateTestEvent::Custom(json!({ + "content": { + "pinned": [pinned_event_id] + }, + "event_id": "$pinned_events_state", + "origin_server_ts": 151393755, + "sender": "@example:localhost", + "state_key": "", + "type": "m.room.pinned_events", + })); + + let room = server + .sync_room(&client, JoinedRoomBuilder::new(room_id).add_state_event(pinned_events_state)) + .await; + + // Get the room event cache and subscribe to pinned events. + let (room_event_cache, _drop_handles) = room.event_cache().await.unwrap(); + + // Subscribe to pinned events - this triggers PinnedEventCache::new() which + // spawns a task that calls reload_from_storage() first. + let (events, mut subscriber) = room_event_cache.subscribe_to_pinned_events().await.unwrap(); + let mut events = events.into(); + + // Wait for the background task to reload the events. + while let Ok(Ok(up)) = timeout(subscriber.recv(), Duration::from_millis(300)).await { + if let RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. } = up { + for diff in diffs { + diff.apply(&mut events); + } + } + if !events.is_empty() { + break; + } + } + + // Verify the pinned event was loaded from the network, and that there wasn't + // any other event loaded (in particular, the thread response shouldn't be + // included in the pinned events). + assert_eq!(events.len(), 1, "Expected pinned events to be loaded from network"); + assert_eq!( + events[0].event_id().unwrap(), + pinned_event_id, + "The pinned event should have been loaded from network" + ); +}