diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index 8328eeeeb..e2e02f9fc 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -17,7 +17,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = ["e2e-encryption", "automatic-room-key-forwarding", "sqlite", "native-tls"] -testing = ["matrix-sdk-sqlite?/testing", "matrix-sdk-indexeddb?/testing", "matrix-sdk-base/testing", "wiremock"] +testing = ["matrix-sdk-sqlite?/testing", "matrix-sdk-indexeddb?/testing", "matrix-sdk-base/testing", "wiremock", "matrix-sdk-test", "assert_matches2"] e2e-encryption = [ "matrix-sdk-base/e2e-encryption", @@ -65,6 +65,7 @@ docsrs = ["e2e-encryption", "sqlite", "indexeddb", "sso-login", "qrcode", "image anyhow = { workspace = true, optional = true } anymap2 = "0.13.0" aquamarine = "0.5.0" +assert_matches2 = { workspace = true, optional = true } as_variant = { workspace = true } async-channel = "2.1.0" async-stream = { workspace = true } @@ -91,6 +92,7 @@ matrix-sdk-base = { workspace = true } matrix-sdk-common = { workspace = true } matrix-sdk-indexeddb = { workspace = true, optional = true } matrix-sdk-sqlite = { workspace = true, optional = true } +matrix-sdk-test = { workspace = true, optional = true } mime = "0.3.16" mime2ext = "0.1.52" rand = { workspace = true , optional = true } diff --git a/crates/matrix-sdk/src/event_cache/mod.rs b/crates/matrix-sdk/src/event_cache/mod.rs index 1d8985409..e7e065f8f 100644 --- a/crates/matrix-sdk/src/event_cache/mod.rs +++ b/crates/matrix-sdk/src/event_cache/mod.rs @@ -75,6 +75,7 @@ use self::{ use crate::{client::ClientInner, room::MessagesOptions, Client, Room}; mod linked_chunk; +pub mod paginator; mod store; /// An error observed in the [`EventCache`]. diff --git a/crates/matrix-sdk/src/event_cache/paginator.rs b/crates/matrix-sdk/src/event_cache/paginator.rs new file mode 100644 index 000000000..a10285d72 --- /dev/null +++ b/crates/matrix-sdk/src/event_cache/paginator.rs @@ -0,0 +1,712 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! The paginator is a stateful helper object that handles reaching an event, +//! either from a cache or network, and surrounding events ("context"). Then, it +//! makes it possible to paginate forward or backward, from that event, until +//! one end of the timeline (front or back) is reached. + +use eyeball::{SharedObservable, Subscriber}; +use matrix_sdk_base::{deserialized_responses::TimelineEvent, SendOutsideWasm, SyncOutsideWasm}; +use ruma::{api::Direction, uint, EventId, OwnedEventId}; +use tokio::sync::Mutex; + +use crate::{ + room::{EventWithContextResponse, Messages, MessagesOptions}, + Room, +}; + +/// Current state of a [`Paginator`]. +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum PaginatorState { + /// The initial state of the paginator. + Initial, + + /// The paginator is fetching the target initial event. + FetchingTargetEvent, + + /// The target initial event could be found, zero or more paginations have + /// happened since then, and the paginator is at rest now. + Idle, + + /// The paginator is… paginating one direction or another. + Paginating, +} + +/// An error that happened when using a [`Paginator`]. +#[derive(Debug, thiserror::Error)] +pub enum PaginatorError { + /// The target event could not be found. + #[error("target event with id {0} could not be found")] + EventNotFound(OwnedEventId), + + /// We're trying to manipulate the paginator in the wrong state. + #[error("expected paginator state {expected:?}, observed {actual:?}")] + InvalidPreviousState { + /// The state we were expecting to see. + expected: PaginatorState, + /// The actual state when doing the check. + actual: PaginatorState, + }, + + /// There was another SDK error while paginating. + #[error("an error happened while paginating")] + SdkError(#[source] crate::Error), +} + +/// A stateful object to reach to an event, and then paginate backward and +/// forward from it. +/// +/// See also the module-level documentation. +pub struct Paginator { + /// The room in which we're going to run the pagination. + room: Box, + + /// Current state of the paginator. + state: SharedObservable, + + /// The token to run the next backward pagination. + prev_batch_token: Mutex>, + + /// The token to run the next forward pagination. + next_batch_token: Mutex>, +} + +#[cfg(not(tarpaulin_include))] +impl std::fmt::Debug for Paginator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Don't include the room in the debug output. + f.debug_struct("Paginator") + .field("state", &self.state.get()) + .field("prev_batch_token", &self.prev_batch_token) + .field("next_batch_token", &self.next_batch_token) + .finish_non_exhaustive() + } +} + +/// The result of a single pagination, be it from +/// [`Paginator::paginate_backward`] or [`Paginator::paginate_forward`]. +#[derive(Debug)] +pub struct PaginationResult { + /// Events returned during this pagination. + /// + /// If this is the result of a backward pagination, then the events are in + /// reverse topological order. + /// + /// If this is the result of a forward pagination, then the events are in + /// topological order. + pub events: Vec, + + /// Did we hit an end of the timeline? + /// + /// If this is the result of a backward pagination, this means we hit the + /// *start* of the timeline. + /// + /// If this is the result of a forward pagination, this means we hit the + /// *end* of the timeline. + pub hit_end_of_timeline: bool, +} + +/// The result of an initial [`Paginator::start_from`] query. +#[derive(Debug)] +pub struct StartFromResult { + /// All the events returned during this pagination, in topological ordering. + pub events: Vec, + + /// Whether the /context query returned a previous batch token. + pub has_prev: bool, + + /// Whether the /context query returned a next batch token. + pub has_next: bool, +} + +impl Paginator { + /// Create a new [`Paginator`], given a room implementation. + pub fn new(room: Box) -> Self { + Self { + room, + state: SharedObservable::new(PaginatorState::Initial), + prev_batch_token: Mutex::new(None), + next_batch_token: Mutex::new(None), + } + } + + /// Check if the current state of the paginator matches the expected one. + fn check_state(&self, expected: PaginatorState) -> Result<(), PaginatorError> { + let actual = self.state.get(); + if actual != expected { + Err(PaginatorError::InvalidPreviousState { expected, actual }) + } else { + Ok(()) + } + } + + /// Returns a subscriber to the internal [`PaginatorState`] machine. + pub fn state(&self) -> Subscriber { + self.state.subscribe() + } + + /// Starts the pagination from the initial event. + /// + /// Only works for fresh [`Paginator`] objects, which are in the + /// [`PaginatorState::Initial`] state. + pub async fn start_from(&self, event_id: &EventId) -> Result { + self.check_state(PaginatorState::Initial)?; + + // Note: it's possible two callers have checked the state and both figured it's + // initial. This check below makes sure there's at most one which can set the + // state to FetchingTargetEvent, preventing a race condition. + if self.state.set_if_not_eq(PaginatorState::FetchingTargetEvent).is_none() { + return Err(PaginatorError::InvalidPreviousState { + expected: PaginatorState::Initial, + actual: PaginatorState::FetchingTargetEvent, + }); + } + + // TODO: do we want to lazy load members? + let lazy_load_members = true; + + let response = self.room.event_with_context(event_id, lazy_load_members).await?; + + let has_prev = response.prev_batch_token.is_some(); + let has_next = response.next_batch_token.is_some(); + *self.prev_batch_token.lock().await = response.prev_batch_token; + *self.next_batch_token.lock().await = response.next_batch_token; + + self.state.set(PaginatorState::Idle); + + // Consolidate the events into a linear timeline, topologically ordered. + // - the events before are returned in the reverse topological order: invert + // them. + // - insert the target event, if set. + // - the events after are returned in the correct topological order. + + let events = response + .events_before + .into_iter() + .rev() + .chain(response.event) + .chain(response.events_after) + .collect(); + + Ok(StartFromResult { events, has_prev, has_next }) + } + + /// Runs a backward pagination, from the current state of the object. + /// + /// Will return immediately if we have already hit the start of the + /// timeline. + /// + /// May return an error if it's already paginating, or if the call to + /// /messages failed. + pub async fn paginate_backward(&self) -> Result { + self.paginate(Direction::Backward, &self.prev_batch_token).await + } + + /// Runs a forward pagination, from the current state of the object. + /// + /// Will return immediately if we have already hit the end of the timeline. + /// + /// May return an error if it's already paginating, or if the call to + /// /messages failed. + pub async fn paginate_forward(&self) -> Result { + self.paginate(Direction::Forward, &self.next_batch_token).await + } + + async fn paginate( + &self, + dir: Direction, + token_lock: &Mutex>, + ) -> Result { + self.check_state(PaginatorState::Idle)?; + + let token = { + let token = token_lock.lock().await; + if token.is_none() { + return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true }); + }; + token.clone() + }; + + // Note: it's possible two callers have checked the state and both figured it's + // idle. This check below makes sure there's at most one which can set the + // state to paginating, preventing a race condition. + if self.state.set_if_not_eq(PaginatorState::Paginating).is_none() { + return Err(PaginatorError::InvalidPreviousState { + expected: PaginatorState::Idle, + actual: PaginatorState::Paginating, + }); + } + + let response_result = + self.room.messages(MessagesOptions::new(dir).from(token.as_deref())).await; + + // In case of error, reset the state to idle. + let response = match response_result { + Ok(res) => res, + Err(err) => { + self.state.set(PaginatorState::Idle); + return Err(err); + } + }; + + let hit_end_of_timeline = response.end.is_none(); + *token_lock.lock().await = response.end; + + // TODO: what to do with state events? + + self.state.set(PaginatorState::Idle); + + Ok(PaginationResult { events: response.chunk, hit_end_of_timeline }) + } +} + +/// A room that can be paginated. +/// +/// Not [`crate::Room`] because we may want to paginate rooms we don't belong +/// to. +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +pub trait PaginableRoom: SendOutsideWasm + SyncOutsideWasm { + /// Runs a /context query for the given room. + /// + /// Must return [`PaginatorError::EventNotFound`] whenever the target event + /// could not be found, instead of causing an http `Err` result. + async fn event_with_context( + &self, + event_id: &EventId, + lazy_load_members: bool, + ) -> Result; + + /// Runs a /messages query for the given room. + async fn messages(&self, opts: MessagesOptions) -> Result; +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl PaginableRoom for Room { + async fn event_with_context( + &self, + event_id: &EventId, + lazy_load_members: bool, + ) -> Result { + let response = match self.event_with_context(event_id, lazy_load_members, uint!(20)).await { + Ok(result) => result, + + Err(err) => { + // If the error was a 404, then the event wasn't found on the server; special + // case this to make it easy to react to such an error. + if let Some(error) = err.as_client_api_error() { + if error.status_code == 404 { + // Event not found + return Err(PaginatorError::EventNotFound(event_id.to_owned())); + } + } + + // Otherwise, just return a wrapped error. + return Err(PaginatorError::SdkError(err)); + } + }; + + Ok(response) + } + + async fn messages(&self, opts: MessagesOptions) -> Result { + self.messages(opts).await.map_err(PaginatorError::SdkError) + } +} + +#[cfg(all(not(target_arch = "wasm32"), test))] +mod tests { + use std::sync::Arc; + + use assert_matches2::assert_let; + use async_trait::async_trait; + use futures_core::Future; + use futures_util::FutureExt as _; + use matrix_sdk_test::async_test; + use once_cell::sync::Lazy; + use ruma::{event_id, room_id, user_id, RoomId, UserId}; + use tokio::{spawn, sync::Notify}; + + use super::*; + use crate::test_utils::{assert_event_matches_msg, events::EventFactory}; + + #[derive(Clone)] + struct TestRoom { + event_factory: Arc, + wait_for_ready: bool, + + target_event_text: Arc>, + next_events: Arc>>, + prev_events: Arc>>, + prev_batch_token: Arc>>, + next_batch_token: Arc>>, + + room_ready: Arc, + } + + impl TestRoom { + fn new(wait_for_ready: bool, room_id: &RoomId, sender: &UserId) -> Self { + let event_factory = Arc::new(EventFactory::default().sender(sender).room(room_id)); + + Self { + event_factory, + wait_for_ready, + + room_ready: Default::default(), + target_event_text: Default::default(), + next_events: Default::default(), + prev_events: Default::default(), + prev_batch_token: Default::default(), + next_batch_token: Default::default(), + } + } + + /// Unblocks the next request. + fn mark_ready(&self) { + self.room_ready.notify_one(); + } + } + + static ROOM_ID: Lazy<&RoomId> = Lazy::new(|| room_id!("!dune:herbert.org")); + static USER_ID: Lazy<&UserId> = Lazy::new(|| user_id!("@paul:atreid.es")); + + #[async_trait] + impl PaginableRoom for TestRoom { + async fn event_with_context( + &self, + event_id: &EventId, + _lazy_load_members: bool, + ) -> Result { + // Wait for the room to be marked as ready first. + if self.wait_for_ready { + self.room_ready.notified().await; + } + + let event = self + .event_factory + .text_msg(self.target_event_text.lock().await.clone()) + .event_id(event_id) + .into_timeline(); + + return Ok(EventWithContextResponse { + event: Some(event), + events_before: self.prev_events.lock().await.clone(), + events_after: self.next_events.lock().await.clone(), + prev_batch_token: self.prev_batch_token.lock().await.clone(), + next_batch_token: self.next_batch_token.lock().await.clone(), + state: Vec::new(), + }); + } + + async fn messages(&self, opts: MessagesOptions) -> Result { + if self.wait_for_ready { + self.room_ready.notified().await; + } + + let (end, events) = match opts.dir { + Direction::Backward => ( + self.prev_batch_token.lock().await.clone(), + self.prev_events.lock().await.clone(), + ), + Direction::Forward => ( + self.next_batch_token.lock().await.clone(), + self.next_events.lock().await.clone(), + ), + }; + + return Ok(Messages { + start: opts.from.unwrap().to_owned(), + end, + chunk: events, + state: Vec::new(), + }); + } + } + + async fn assert_invalid_state( + task: impl Future>, + expected: PaginatorState, + actual: PaginatorState, + ) { + assert_let!( + Err(PaginatorError::InvalidPreviousState { + expected: real_expected, + actual: real_actual + }) = task.await + ); + assert_eq!(real_expected, expected); + assert_eq!(real_actual, actual); + } + + #[async_test] + async fn test_start_from() { + // Prepare test data. + let room = Box::new(TestRoom::new(false, *ROOM_ID, *USER_ID)); + + let event_id = event_id!("$yoyoyo"); + let event_factory = &room.event_factory; + + *room.target_event_text.lock().await = "fetch_from".to_owned(); + *room.prev_events.lock().await = (0..10) + .rev() + .map(|i| { + TimelineEvent::new( + event_factory.text_msg(format!("before-{i}")).into_raw_timeline(), + ) + }) + .collect(); + *room.next_events.lock().await = (0..10) + .map(|i| { + TimelineEvent::new(event_factory.text_msg(format!("after-{i}")).into_raw_timeline()) + }) + .collect(); + + // When I call `Paginator::start_from`, it works, + let paginator = Arc::new(Paginator::new(room.clone())); + let context = paginator.start_from(event_id).await.expect("start_from should work"); + + assert!(!context.has_prev); + assert!(!context.has_next); + + // And I get the events I expected. + + // 10 events before, the target event, 10 events after. + assert_eq!(context.events.len(), 21); + + for i in 0..10 { + assert_event_matches_msg(&context.events[i], &format!("before-{i}")); + } + + assert_event_matches_msg(&context.events[10], "fetch_from"); + assert_eq!(context.events[10].event.deserialize().unwrap().event_id(), event_id); + + for i in 0..10 { + assert_event_matches_msg(&context.events[i + 11], &format!("after-{i}")); + } + } + + #[async_test] + async fn test_paginate_backward() { + // Prepare test data. + let room = Box::new(TestRoom::new(false, *ROOM_ID, *USER_ID)); + + let event_id = event_id!("$yoyoyo"); + let event_factory = &room.event_factory; + + *room.target_event_text.lock().await = "initial".to_owned(); + *room.prev_batch_token.lock().await = Some("prev".to_owned()); + + // When I call `Paginator::start_from`, it works, + let paginator = Arc::new(Paginator::new(room.clone())); + let context = paginator.start_from(event_id).await.expect("start_from should work"); + + // And I get the events I expected. + assert_eq!(context.events.len(), 1); + assert_event_matches_msg(&context.events[0], "initial"); + assert_eq!(context.events[0].event.deserialize().unwrap().event_id(), event_id); + + // There's a previous batch. + assert!(context.has_prev); + assert!(!context.has_next); + + // Preparing data for the next back-pagination. + *room.prev_events.lock().await = vec![event_factory.text_msg("previous").into_timeline()]; + *room.prev_batch_token.lock().await = Some("prev2".to_owned()); + + // When I backpaginate, I get the events I expect. + let prev = paginator.paginate_backward().await.expect("paginate backward should work"); + assert!(!prev.hit_end_of_timeline); + assert_eq!(prev.events.len(), 1); + assert_event_matches_msg(&prev.events[0], "previous"); + + // And I can backpaginate again, because there's a prev batch token + // still. + *room.prev_events.lock().await = vec![event_factory.text_msg("oldest").into_timeline()]; + *room.prev_batch_token.lock().await = None; + + let prev = paginator + .paginate_backward() + .await + .expect("paginate backward the second time should work"); + assert!(prev.hit_end_of_timeline); + assert_eq!(prev.events.len(), 1); + assert_event_matches_msg(&prev.events[0], "oldest"); + + // I've hit the start of the timeline, but back-paginating again will + // return immediately. + let prev = paginator + .paginate_backward() + .await + .expect("paginate backward the third time should work"); + assert!(prev.hit_end_of_timeline); + assert!(prev.events.is_empty()); + } + + #[async_test] + async fn test_paginate_forward() { + // Prepare test data. + let room = Box::new(TestRoom::new(false, *ROOM_ID, *USER_ID)); + + let event_id = event_id!("$yoyoyo"); + let event_factory = &room.event_factory; + + *room.target_event_text.lock().await = "initial".to_owned(); + *room.next_batch_token.lock().await = Some("next".to_owned()); + + // When I call `Paginator::start_from`, it works, + let paginator = Arc::new(Paginator::new(room.clone())); + let context = paginator.start_from(event_id).await.expect("start_from should work"); + + // And I get the events I expected. + assert_eq!(context.events.len(), 1); + assert_event_matches_msg(&context.events[0], "initial"); + assert_eq!(context.events[0].event.deserialize().unwrap().event_id(), event_id); + + // There's a next batch. + assert!(!context.has_prev); + assert!(context.has_next); + + // Preparing data for the next forward-pagination. + *room.next_events.lock().await = vec![event_factory.text_msg("next").into_timeline()]; + *room.next_batch_token.lock().await = Some("next2".to_owned()); + + // When I forward-paginate, I get the events I expect. + let next = paginator.paginate_forward().await.expect("paginate forward should work"); + assert!(!next.hit_end_of_timeline); + assert_eq!(next.events.len(), 1); + assert_event_matches_msg(&next.events[0], "next"); + + // And I can forward-paginate again, because there's a prev batch token + // still. + *room.next_events.lock().await = vec![event_factory.text_msg("latest").into_timeline()]; + *room.next_batch_token.lock().await = None; + + let next = paginator + .paginate_forward() + .await + .expect("paginate forward the second time should work"); + assert!(next.hit_end_of_timeline); + assert_eq!(next.events.len(), 1); + assert_event_matches_msg(&next.events[0], "latest"); + + // I've hit the start of the timeline, but back-paginating again will + // return immediately. + let next = paginator + .paginate_forward() + .await + .expect("paginate forward the third time should work"); + assert!(next.hit_end_of_timeline); + assert!(next.events.is_empty()); + } + + #[async_test] + async fn test_state() { + let room = Box::new(TestRoom::new(true, *ROOM_ID, *USER_ID)); + + *room.prev_batch_token.lock().await = Some("prev".to_owned()); + *room.next_batch_token.lock().await = Some("next".to_owned()); + + let paginator = Arc::new(Paginator::new(room.clone())); + + let event_id = event_id!("$yoyoyo"); + + let mut state = paginator.state(); + + assert_eq!(state.get(), PaginatorState::Initial); + assert!(state.next().now_or_never().is_none()); + + // Attempting to run pagination must fail and not change the state. + assert_invalid_state( + paginator.paginate_backward(), + PaginatorState::Idle, + PaginatorState::Initial, + ) + .await; + + assert!(state.next().now_or_never().is_none()); + + // Running the initial query must work. + let p = paginator.clone(); + let join_handle = spawn(async move { p.start_from(event_id).await }); + + assert_eq!(state.next().await, Some(PaginatorState::FetchingTargetEvent)); + assert!(state.next().now_or_never().is_none()); + + // The query is pending. Running other operations must fail. + assert_invalid_state( + paginator.start_from(event_id), + PaginatorState::Initial, + PaginatorState::FetchingTargetEvent, + ) + .await; + + assert_invalid_state( + paginator.paginate_backward(), + PaginatorState::Idle, + PaginatorState::FetchingTargetEvent, + ) + .await; + + assert!(state.next().now_or_never().is_none()); + + // Mark the dummy room as ready. The query may now terminate. + room.mark_ready(); + + // After fetching the initial event data, the paginator switches to `Idle`. + assert_eq!(state.next().await, Some(PaginatorState::Idle)); + + join_handle.await.expect("joined failed").expect("/context failed"); + + assert!(state.next().now_or_never().is_none()); + + let p = paginator.clone(); + let join_handle = spawn(async move { p.paginate_backward().await }); + + assert_eq!(state.next().await, Some(PaginatorState::Paginating)); + + // The query is pending. Running other operations must fail. + assert_invalid_state( + paginator.start_from(event_id), + PaginatorState::Initial, + PaginatorState::Paginating, + ) + .await; + + assert_invalid_state( + paginator.paginate_backward(), + PaginatorState::Idle, + PaginatorState::Paginating, + ) + .await; + + assert_invalid_state( + paginator.paginate_forward(), + PaginatorState::Idle, + PaginatorState::Paginating, + ) + .await; + + assert!(state.next().now_or_never().is_none()); + + room.mark_ready(); + + assert_eq!(state.next().await, Some(PaginatorState::Idle)); + + join_handle.await.expect("joined failed").expect("/messages failed"); + + assert!(state.next().now_or_never().is_none()); + } +} diff --git a/crates/matrix-sdk/src/room/mod.rs b/crates/matrix-sdk/src/room/mod.rs index 930b19c49..cccb16a5d 100644 --- a/crates/matrix-sdk/src/room/mod.rs +++ b/crates/matrix-sdk/src/room/mod.rs @@ -83,13 +83,10 @@ use thiserror::Error; use tokio::sync::broadcast; use tracing::{debug, info, instrument, warn}; -use self::{ - futures::{SendAttachment, SendMessageLikeEvent, SendRawMessageLikeEvent}, - messages::EventWithContextResponse, -}; +use self::futures::{SendAttachment, SendMessageLikeEvent, SendRawMessageLikeEvent}; pub use self::{ member::{RoomMember, RoomMemberRole}, - messages::{Messages, MessagesOptions}, + messages::{EventWithContextResponse, Messages, MessagesOptions}, }; #[cfg(doc)] use crate::event_cache::EventCache; diff --git a/crates/matrix-sdk/src/test_utils.rs b/crates/matrix-sdk/src/test_utils.rs index 63b0f8d34..a9f2a60d7 100644 --- a/crates/matrix-sdk/src/test_utils.rs +++ b/crates/matrix-sdk/src/test_utils.rs @@ -2,16 +2,37 @@ #![allow(dead_code)] -use matrix_sdk_base::SessionMeta; -use ruma::{api::MatrixVersion, device_id, user_id}; +use assert_matches2::assert_let; +use matrix_sdk_base::{deserialized_responses::SyncTimelineEvent, SessionMeta}; +use ruma::{ + api::MatrixVersion, + device_id, + events::{room::message::MessageType, AnySyncMessageLikeEvent, AnySyncTimelineEvent}, + user_id, +}; use url::Url; +pub mod events; + use crate::{ config::RequestConfig, matrix_auth::{MatrixSession, MatrixSessionTokens}, Client, ClientBuilder, }; +/// Checks that an event is a message-like text event with the given text. +#[track_caller] +pub fn assert_event_matches_msg>(event: &E, expected: &str) { + let event: SyncTimelineEvent = event.clone().into(); + let event = event.event.deserialize().unwrap(); + assert_let!( + AnySyncTimelineEvent::MessageLike(AnySyncMessageLikeEvent::RoomMessage(message)) = event + ); + let message = message.as_original().unwrap(); + assert_let!(MessageType::Text(text) = &message.content.msgtype); + assert_eq!(text.body, expected); +} + /// A [`ClientBuilder`] fit for testing, using the given `homeserver_url` (or /// localhost:1234). pub fn test_client_builder(homeserver_url: Option) -> ClientBuilder { diff --git a/crates/matrix-sdk/src/test_utils/events.rs b/crates/matrix-sdk/src/test_utils/events.rs new file mode 100644 index 000000000..f57795d50 --- /dev/null +++ b/crates/matrix-sdk/src/test_utils/events.rs @@ -0,0 +1,150 @@ +// Copyright 2024 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(missing_docs)] + +use std::sync::atomic::{AtomicU64, Ordering::SeqCst}; + +use matrix_sdk_base::deserialized_responses::{SyncTimelineEvent, TimelineEvent}; +use matrix_sdk_test::{sync_timeline_event, timeline_event}; +use ruma::{ + events::{ + room::message::RoomMessageEventContent, AnySyncTimelineEvent, AnyTimelineEvent, + EventContent, + }, + serde::Raw, + server_name, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, + RoomId, UserId, +}; +use serde::Serialize; + +#[derive(Debug)] +pub struct EventBuilder { + sender: Option, + room: Option, + event_id: Option, + content: E, + server_ts: MilliSecondsSinceUnixEpoch, +} + +impl EventBuilder +where + E::EventType: Serialize, +{ + pub fn room(mut self, room_id: &RoomId) -> Self { + self.room = Some(room_id.to_owned()); + self + } + + pub fn sender(mut self, sender: &UserId) -> Self { + self.sender = Some(sender.to_owned()); + self + } + + pub fn event_id(mut self, event_id: &EventId) -> Self { + self.event_id = Some(event_id.to_owned()); + self + } + + pub fn into_raw_timeline(self) -> Raw { + let room_id = self.room.expect("we should have a room id at this point"); + let event_id = + self.event_id.unwrap_or_else(|| EventId::new(room_id.server_name().unwrap())); + + timeline_event!({ + "type": self.content.event_type(), + "content": self.content, + "event_id": event_id, + "sender": self.sender.expect("we should have a sender user id at this point"), + "room_id": room_id, + "origin_server_ts": self.server_ts + }) + } + + pub fn into_timeline(self) -> TimelineEvent { + TimelineEvent::new(self.into_raw_timeline()) + } + + pub fn into_raw_sync(self) -> Raw { + let event_id = self + .event_id + .or_else(|| self.room.map(|room_id| EventId::new(room_id.server_name().unwrap()))) + .unwrap_or_else(|| EventId::new(server_name!("dummy.org"))); + + sync_timeline_event!({ + "type": self.content.event_type(), + "content": self.content, + "event_id": event_id, + "sender": self.sender.expect("we should have a sender user id at this point"), + "origin_server_ts": self.server_ts + }) + } + + pub fn into_sync(self) -> SyncTimelineEvent { + SyncTimelineEvent::new(self.into_raw_sync()) + } +} + +#[derive(Debug, Default)] +pub struct EventFactory { + next_ts: AtomicU64, + sender: Option, + room: Option, +} + +impl EventFactory { + pub fn new() -> Self { + Self { next_ts: AtomicU64::new(0), sender: None, room: None } + } + + pub fn room(mut self, room_id: &RoomId) -> Self { + self.room = Some(room_id.to_owned()); + self + } + + pub fn sender(mut self, sender: &UserId) -> Self { + self.sender = Some(sender.to_owned()); + self + } + + fn next_server_ts(&self) -> MilliSecondsSinceUnixEpoch { + MilliSecondsSinceUnixEpoch( + self.next_ts + .fetch_add(1, SeqCst) + .try_into() + .expect("server timestamp should fit in js_int::UInt"), + ) + } + + pub fn event(&self, content: E) -> EventBuilder { + EventBuilder { + sender: self.sender.clone(), + room: self.room.clone(), + server_ts: self.next_server_ts(), + event_id: None, + content, + } + } + + pub fn text_msg(&self, content: impl Into) -> EventBuilder { + self.event(RoomMessageEventContent::text_plain(content.into())) + } + + /// Set the next server timestamp. + /// + /// Timestamps will continue to increase by 1 (millisecond) from that value. + pub fn set_next_ts(&self, value: u64) { + self.next_ts.store(value, SeqCst); + } +} diff --git a/crates/matrix-sdk/tests/integration/event_cache.rs b/crates/matrix-sdk/tests/integration/event_cache.rs index dae21ab0d..c8c30bc4a 100644 --- a/crates/matrix-sdk/tests/integration/event_cache.rs +++ b/crates/matrix-sdk/tests/integration/event_cache.rs @@ -3,7 +3,7 @@ use std::time::Duration; use assert_matches2::{assert_let, assert_matches}; use matrix_sdk::{ event_cache::{BackPaginationOutcome, EventCacheError, RoomEventCacheUpdate}, - test_utils::logged_in_client_with_server, + test_utils::{assert_event_matches_msg, logged_in_client_with_server}, }; use matrix_sdk_common::deserialized_responses::SyncTimelineEvent; use matrix_sdk_test::{ @@ -11,10 +11,7 @@ use matrix_sdk_test::{ }; use ruma::{ event_id, - events::{ - room::message::{MessageType, RoomMessageEventContent}, - AnySyncMessageLikeEvent, AnySyncTimelineEvent, AnyTimelineEvent, - }, + events::{room::message::RoomMessageEventContent, AnyTimelineEvent}, room_id, serde::Raw, user_id, @@ -28,17 +25,6 @@ use wiremock::{ use crate::mock_sync; -#[track_caller] -fn assert_event_matches_msg(event: &SyncTimelineEvent, expected: &str) { - let event = event.event.deserialize().unwrap(); - assert_let!( - AnySyncTimelineEvent::MessageLike(AnySyncMessageLikeEvent::RoomMessage(message)) = event - ); - let message = message.as_original().unwrap(); - assert_let!(MessageType::Text(text) = &message.content.msgtype); - assert_eq!(text.body, expected); -} - #[async_test] async fn test_must_explicitly_subscribe() { let (client, server) = logged_in_client_with_server().await; @@ -290,8 +276,8 @@ async fn test_backpaginate_once() { assert_let!(BackPaginationOutcome::Success { events, reached_start } = outcome); assert!(reached_start); - assert_event_matches_msg(&events[0].clone().into(), "world"); - assert_event_matches_msg(&events[1].clone().into(), "hello"); + assert_event_matches_msg(&events[0], "world"); + assert_event_matches_msg(&events[1], "hello"); assert_eq!(events.len(), 2); assert!(room_stream.is_empty()); @@ -390,9 +376,9 @@ async fn test_backpaginate_multiple_iterations() { assert_eq!(num_iterations, 2); assert!(global_reached_start); - assert_event_matches_msg(&global_events[0].clone().into(), "world"); - assert_event_matches_msg(&global_events[1].clone().into(), "hello"); - assert_event_matches_msg(&global_events[2].clone().into(), "oh well"); + assert_event_matches_msg(&global_events[0], "world"); + assert_event_matches_msg(&global_events[1], "hello"); + assert_event_matches_msg(&global_events[2], "oh well"); assert_eq!(global_events.len(), 3); // And next time I'll open the room, I'll get the events in the right order. @@ -581,7 +567,7 @@ async fn test_backpaginating_without_token() { assert!(reached_start); // And we get notified about the new event. - assert_event_matches_msg(&events[0].clone().into(), "hi"); + assert_event_matches_msg(&events[0], "hi"); assert_eq!(events.len(), 1); assert!(room_stream.is_empty()); diff --git a/testing/matrix-sdk-integration-testing/src/tests/room.rs b/testing/matrix-sdk-integration-testing/src/tests/room.rs index bcb7cd712..c9b52e4d8 100644 --- a/testing/matrix-sdk-integration-testing/src/tests/room.rs +++ b/testing/matrix-sdk-integration-testing/src/tests/room.rs @@ -3,17 +3,14 @@ use std::time::Duration; use anyhow::Result; use assert_matches2::{assert_let, assert_matches}; use matrix_sdk::{ - deserialized_responses::TimelineEvent, room::MessagesOptions, ruma::{ api::client::room::create_room::v3::Request as CreateRoomRequest, assign, event_id, - events::{ - room::message::{MessageType, RoomMessageEventContent}, - AnyMessageLikeEvent, AnyStateEvent, AnyTimelineEvent, - }, + events::{room::message::RoomMessageEventContent, AnyStateEvent, AnyTimelineEvent}, uint, }, + test_utils::assert_event_matches_msg, RoomState, }; use tokio::{spawn, time::sleep}; @@ -199,15 +196,3 @@ async fn test_event_with_context() -> Result<()> { Ok(()) } - -// TODO: Try to avoid duplication with the other (almost) copy of this function? -// This can't go straight into `matrix-sdk-test` because it needs to depend on -// `matrix-sdk` first. -#[track_caller] -fn assert_event_matches_msg(event: &TimelineEvent, expected: &str) { - let event = event.event.deserialize().unwrap(); - assert_let!(AnyTimelineEvent::MessageLike(AnyMessageLikeEvent::RoomMessage(message)) = event); - let message = message.as_original().unwrap(); - assert_let!(MessageType::Text(text) = &message.content.msgtype); - assert_eq!(text.body, expected); -}