diff --git a/Cargo.lock b/Cargo.lock index 29a6cc56c..547f4e6ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3241,6 +3241,7 @@ dependencies = [ "assert_matches2", "assign", "eyeball-im", + "futures", "futures-core", "futures-util", "http", @@ -3249,11 +3250,13 @@ dependencies = [ "matrix-sdk-ui", "once_cell", "rand 0.8.5", + "reqwest", "serde_json", "stream_assert", "tempfile", "tokio", "tracing", + "wiremock", ] [[package]] diff --git a/crates/matrix-sdk-ui/src/room_list_service/room_list.rs b/crates/matrix-sdk-ui/src/room_list_service/room_list.rs index 259d507f8..649ebecf5 100644 --- a/crates/matrix-sdk-ui/src/room_list_service/room_list.rs +++ b/crates/matrix-sdk-ui/src/room_list_service/room_list.rs @@ -167,8 +167,9 @@ impl RoomList { } } -/// This function remembers the current state of the unfiltered room list, so it knows where all rooms are. -/// When the receiver is triggered, a Set operation for the room position is inserted to the stream. +/// This function remembers the current state of the unfiltered room list, so it +/// knows where all rooms are. When the receiver is triggered, a Set operation +/// for the room position is inserted to the stream. fn merge_stream_and_receiver( mut raw_current_values: Vector, raw_stream: impl Stream>>, diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index e6e96e869..7e64aea72 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -15,7 +15,6 @@ use std::{fmt, sync::Arc}; -use bytes::Bytes; use matrix_sdk_base::{store::StoreConfig, BaseClient}; use ruma::{ api::{client::discovery::discover_homeserver, error::FromHttpResponseError, MatrixVersion}, @@ -91,7 +90,6 @@ pub struct ClientBuilder { base_client: Option, #[cfg(feature = "e2e-encryption")] encryption_settings: EncryptionSettings, - response_preprocessor: Option, &mut http::Response)>, } impl ClientBuilder { @@ -109,7 +107,6 @@ impl ClientBuilder { base_client: None, #[cfg(feature = "e2e-encryption")] encryption_settings: Default::default(), - response_preprocessor: None, } } @@ -329,14 +326,6 @@ impl ClientBuilder { self } - pub fn response_preprocessor( - mut self, - response_preprocessor: fn(&http::Request, &mut http::Response), - ) -> Self { - self.response_preprocessor = Some(response_preprocessor); - self - } - /// Create a [`Client`] with the options set on this builder. /// /// # Errors @@ -372,11 +361,7 @@ impl ClientBuilder { BaseClient::with_store_config(build_store_config(self.store_config).await?) }; - let http_client = HttpClient::new( - inner_http_client.clone(), - self.request_config, - self.response_preprocessor, - ); + let http_client = HttpClient::new(inner_http_client.clone(), self.request_config); #[cfg(feature = "experimental-oidc")] let mut authentication_server_info = None; diff --git a/crates/matrix-sdk/src/http_client/mod.rs b/crates/matrix-sdk/src/http_client/mod.rs index 0c6bee66c..72390e427 100644 --- a/crates/matrix-sdk/src/http_client/mod.rs +++ b/crates/matrix-sdk/src/http_client/mod.rs @@ -49,21 +49,11 @@ pub(crate) struct HttpClient { pub(crate) inner: reqwest::Client, pub(crate) request_config: RequestConfig, next_request_id: Arc, - response_preprocessor: Option, &mut http::Response)>, } impl HttpClient { - pub(crate) fn new( - inner: reqwest::Client, - request_config: RequestConfig, - response_preprocessor: Option, &mut http::Response)>, - ) -> Self { - HttpClient { - inner, - request_config, - next_request_id: AtomicU64::new(0).into(), - response_preprocessor, - } + pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self { + HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() } } fn get_request_id(&self) -> String { diff --git a/crates/matrix-sdk/src/http_client/native.rs b/crates/matrix-sdk/src/http_client/native.rs index f3892a663..54955104f 100644 --- a/crates/matrix-sdk/src/http_client/native.rs +++ b/crates/matrix-sdk/src/http_client/native.rs @@ -92,10 +92,9 @@ impl HttpClient { } }; - let mut response = - send_request(&self.inner, &request, config.timeout, send_progress) - .await - .map_err(error_type)?; + let response = send_request(&self.inner, &request, config.timeout, send_progress) + .await + .map_err(error_type)?; let status_code = response.status(); let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); @@ -116,10 +115,6 @@ impl HttpClient { } } - if let Some(preprocessor) = self.response_preprocessor { - preprocessor(&request, &mut response); - } - R::IncomingResponse::try_from_http_response(response) .map_err(|e| error_type(HttpError::from(e))) } diff --git a/testing/matrix-sdk-integration-testing/Cargo.toml b/testing/matrix-sdk-integration-testing/Cargo.toml index 72c35d606..1971999dc 100644 --- a/testing/matrix-sdk-integration-testing/Cargo.toml +++ b/testing/matrix-sdk-integration-testing/Cargo.toml @@ -11,6 +11,7 @@ assert_matches2 = { workspace = true } anyhow = { workspace = true } assign = "1" eyeball-im = { workspace = true } +futures = { version = "0.3.29", features = ["executor"] } futures-core = { workspace = true } futures-util = { workspace = true } http = "0.2.11" @@ -20,7 +21,9 @@ matrix-sdk-test = { workspace = true } once_cell = { workspace = true } rand = { workspace = true } stream_assert = "0.1.1" +reqwest = "0.11.20" serde_json = "1.0.108" tempfile = "3.3.0" tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } tracing = { workspace = true } +wiremock = "0.5.21" diff --git a/testing/matrix-sdk-integration-testing/src/helpers.rs b/testing/matrix-sdk-integration-testing/src/helpers.rs index 1596a245b..8b71f5136 100644 --- a/testing/matrix-sdk-integration-testing/src/helpers.rs +++ b/testing/matrix-sdk-integration-testing/src/helpers.rs @@ -9,7 +9,6 @@ use std::{ use anyhow::Result; use assign::assign; use matrix_sdk::{ - bytes::Bytes, config::{RequestConfig, SyncSettings}, encryption::EncryptionSettings, ruma::api::client::{account::register::v3::Request as RegistrationRequest, uiaa}, @@ -26,7 +25,7 @@ pub struct TestClientBuilder { username: String, use_sqlite: bool, encryption_settings: EncryptionSettings, - response_preprocessor: Option, &mut http::Response)>, + http_proxy: Option, } impl TestClientBuilder { @@ -35,7 +34,7 @@ impl TestClientBuilder { username: username.into(), use_sqlite: false, encryption_settings: Default::default(), - response_preprocessor: None, + http_proxy: None, } } @@ -55,11 +54,8 @@ impl TestClientBuilder { self } - pub fn response_preprocessor( - mut self, - r: fn(&http::Request, &mut http::Response), - ) -> Self { - self.response_preprocessor = Some(r); + pub fn http_proxy(mut self, proxy: String) -> Self { + self.http_proxy = Some(proxy); self } @@ -83,8 +79,8 @@ impl TestClientBuilder { .with_encryption_settings(self.encryption_settings) .request_config(RequestConfig::short_retry()); - if let Some(response_preprocessor) = self.response_preprocessor { - client_builder = client_builder.response_preprocessor(response_preprocessor); + if let Some(proxy) = self.http_proxy { + client_builder = client_builder.proxy(proxy); } let client = if self.use_sqlite { diff --git a/testing/matrix-sdk-integration-testing/src/tests/sliding_sync/room.rs b/testing/matrix-sdk-integration-testing/src/tests/sliding_sync/room.rs index ce4be07b5..6042d699a 100644 --- a/testing/matrix-sdk-integration-testing/src/tests/sliding_sync/room.rs +++ b/testing/matrix-sdk-integration-testing/src/tests/sliding_sync/room.rs @@ -11,16 +11,14 @@ use futures_util::{pin_mut, StreamExt as _}; use matrix_sdk::{ bytes::Bytes, config::SyncSettings, + reqwest, ruma::{ - api::{ - client::{ - receipt::create_receipt::v3::ReceiptType, - room::create_room::v3::{Request as CreateRoomRequest, RoomPreset}, - sync::sync_events::v4::{ - AccountDataConfig, E2EEConfig, ReceiptsConfig, ToDeviceConfig, - }, + api::client::{ + receipt::create_receipt::v3::ReceiptType, + room::create_room::v3::{Request as CreateRoomRequest, RoomPreset}, + sync::sync_events::v4::{ + AccountDataConfig, E2EEConfig, ReceiptsConfig, ToDeviceConfig, }, - IncomingResponse, }, assign, events::{ @@ -40,7 +38,8 @@ use tokio::{ sync::Mutex, time::{sleep, timeout}, }; -use tracing::{error, warn}; +use tracing::{error, info, warn}; +use wiremock::{matchers::AnyMatcher, Mock, MockServer}; use crate::helpers::TestClientBuilder; @@ -547,16 +546,9 @@ async fn test_room_notification_count() -> Result<()> { Ok(()) } -struct DelayedDecryption { - active: bool, - queued: Vec, -} -static DELAYED_DECRYPTION: StdMutex = - StdMutex::new(DelayedDecryption { active: true, queued: Vec::new() }); - -fn delayed_decryption(request: &http::Request, response: &mut http::Response) { - dbg!(request.uri()); - let Ok(mut json) = serde_json::from_slice::(response.body()) else { +static DROP_TODEVICE: StdMutex = StdMutex::new(true); +fn drop_todevice(response: &mut Bytes) { + let Ok(mut json) = serde_json::from_slice::(response) else { return; }; let Some(extensions) = json.get_mut("extensions").and_then(|e| e.as_object_mut()) else { @@ -565,22 +557,80 @@ fn delayed_decryption(request: &http::Request, response: &mut http::Respo let Some(to_device) = extensions.remove("to_device") else { return; }; - let d = DELAYED_DECRYPTION.lock().unwrap(); - if d.active { - *response.body_mut() = serde_json::to_vec(&json).unwrap().into(); + if *DROP_TODEVICE.lock().unwrap() { + info!("Dropping to_device: {to_device}"); + *response = serde_json::to_vec(&json).unwrap().into(); return; - } else { - dbg!(to_device); - // d.queued.push(to_device); + } +} + +struct CustomResponder { + client: reqwest::Client, + response_preprocessor: fn(&mut Bytes), +} +impl CustomResponder { + fn new(response_preprocessor: fn(&mut Bytes)) -> Self { + Self { client: reqwest::Client::new(), response_preprocessor } + } +} +impl wiremock::Respond for CustomResponder { + fn respond(&self, request: &wiremock::Request) -> wiremock::ResponseTemplate { + let mut req = self.client.request( + request.method.to_string().parse().expect("All methods exist"), + request.url.clone(), + ); + for header in &request.headers { + for value in header.1 { + req = req.header(header.0.to_string(), value.to_string()); + } + } + req = req.body(request.body.clone()); + + // Run await inside of non-async fn by spawning a new thread and creating a new + // runtime. Is there a better way? + let response_preprocessor = self.response_preprocessor; + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let response = timeout(Duration::from_secs(2), req.send()).await; + + if let Ok(Ok(response)) = response { + let mut r = wiremock::ResponseTemplate::new(u16::from(response.status())); + for header in response.headers() { + if header.0 == "Content-Length" { + continue; + } + r = r.append_header(header.0.clone(), header.1.clone()); + } + + let mut bytes = response.bytes().await.unwrap_or_default(); + response_preprocessor(&mut bytes); + + r = r.set_body_bytes(bytes); + r + } else { + // Gateway timeout + let r = wiremock::ResponseTemplate::new(504); + r + } + }) + }) + .join() + .expect("Thread panicked") } } #[tokio::test] async fn test_delayed_decryption_latest_event() -> Result<()> { + let server = MockServer::start().await; + server + .register(Mock::given(AnyMatcher).respond_with(CustomResponder::new(drop_todevice))) + .await; + let alice = TestClientBuilder::new("alice".to_owned()) .randomize_username() .use_sqlite() - .response_preprocessor(delayed_decryption) + .http_proxy(server.uri()) .build() .await?; let bob = @@ -615,7 +665,7 @@ async fn test_delayed_decryption_latest_event() -> Result<()> { .await?; let s = sliding_alice.clone(); - tokio::task::spawn(async move { + spawn(async move { let stream = s.sync(); pin_mut!(stream); while let Some(up) = stream.next().await { @@ -624,7 +674,7 @@ async fn test_delayed_decryption_latest_event() -> Result<()> { }); let s = sliding_bob.clone(); - tokio::task::spawn(async move { + spawn(async move { let stream = s.sync(); pin_mut!(stream); while let Some(up) = stream.next().await { @@ -661,14 +711,14 @@ async fn test_delayed_decryption_latest_event() -> Result<()> { .entries_with_dynamic_adapters(10, alice.roominfo_update_receiver()); entries.set_filter(Box::new(new_filter_all(vec![]))); pin_mut!(stream); - dbg!(timeout(Duration::from_millis(100), stream.next()).await.unwrap()); + timeout(Duration::from_millis(100), stream.next()).await.unwrap(); assert!(timeout(Duration::from_millis(100), stream.next()).await.is_err()); assert!(matches!(alice_room.latest_event(), None)); - DELAYED_DECRYPTION.lock().unwrap().active = false; + *DROP_TODEVICE.lock().unwrap() = false; sleep(Duration::from_secs(1)).await; - dbg!(&alice_room.latest_event().unwrap().event().event); - dbg!(timeout(Duration::from_millis(100), stream.next()).await.unwrap()); + alice_room.latest_event().unwrap(); + timeout(Duration::from_millis(100), stream.next()).await.unwrap(); assert!(timeout(Duration::from_millis(100), stream.next()).await.is_err()); Ok(())