sliding_sync: Refactor delayed decryption test

Signed-off-by: Timo Kösters <timo@koesters.xyz>
This commit is contained in:
Timo Kösters
2024-01-23 15:27:28 +01:00
committed by Benjamin Bouvier
parent 50ed681a4e
commit 79f97504f8
8 changed files with 103 additions and 80 deletions

3
Cargo.lock generated
View File

@@ -3241,6 +3241,7 @@ dependencies = [
"assert_matches2",
"assign",
"eyeball-im",
"futures",
"futures-core",
"futures-util",
"http",
@@ -3249,11 +3250,13 @@ dependencies = [
"matrix-sdk-ui",
"once_cell",
"rand 0.8.5",
"reqwest",
"serde_json",
"stream_assert",
"tempfile",
"tokio",
"tracing",
"wiremock",
]
[[package]]

View File

@@ -167,8 +167,9 @@ impl RoomList {
}
}
/// This function remembers the current state of the unfiltered room list, so it knows where all rooms are.
/// When the receiver is triggered, a Set operation for the room position is inserted to the stream.
/// This function remembers the current state of the unfiltered room list, so it
/// knows where all rooms are. When the receiver is triggered, a Set operation
/// for the room position is inserted to the stream.
fn merge_stream_and_receiver(
mut raw_current_values: Vector<RoomListEntry>,
raw_stream: impl Stream<Item = Vec<VectorDiff<RoomListEntry>>>,

View File

@@ -15,7 +15,6 @@
use std::{fmt, sync::Arc};
use bytes::Bytes;
use matrix_sdk_base::{store::StoreConfig, BaseClient};
use ruma::{
api::{client::discovery::discover_homeserver, error::FromHttpResponseError, MatrixVersion},
@@ -91,7 +90,6 @@ pub struct ClientBuilder {
base_client: Option<BaseClient>,
#[cfg(feature = "e2e-encryption")]
encryption_settings: EncryptionSettings,
response_preprocessor: Option<fn(&http::Request<Bytes>, &mut http::Response<Bytes>)>,
}
impl ClientBuilder {
@@ -109,7 +107,6 @@ impl ClientBuilder {
base_client: None,
#[cfg(feature = "e2e-encryption")]
encryption_settings: Default::default(),
response_preprocessor: None,
}
}
@@ -329,14 +326,6 @@ impl ClientBuilder {
self
}
pub fn response_preprocessor(
mut self,
response_preprocessor: fn(&http::Request<Bytes>, &mut http::Response<Bytes>),
) -> Self {
self.response_preprocessor = Some(response_preprocessor);
self
}
/// Create a [`Client`] with the options set on this builder.
///
/// # Errors
@@ -372,11 +361,7 @@ impl ClientBuilder {
BaseClient::with_store_config(build_store_config(self.store_config).await?)
};
let http_client = HttpClient::new(
inner_http_client.clone(),
self.request_config,
self.response_preprocessor,
);
let http_client = HttpClient::new(inner_http_client.clone(), self.request_config);
#[cfg(feature = "experimental-oidc")]
let mut authentication_server_info = None;

View File

@@ -49,21 +49,11 @@ pub(crate) struct HttpClient {
pub(crate) inner: reqwest::Client,
pub(crate) request_config: RequestConfig,
next_request_id: Arc<AtomicU64>,
response_preprocessor: Option<fn(&http::Request<Bytes>, &mut http::Response<Bytes>)>,
}
impl HttpClient {
pub(crate) fn new(
inner: reqwest::Client,
request_config: RequestConfig,
response_preprocessor: Option<fn(&http::Request<Bytes>, &mut http::Response<Bytes>)>,
) -> Self {
HttpClient {
inner,
request_config,
next_request_id: AtomicU64::new(0).into(),
response_preprocessor,
}
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
}
fn get_request_id(&self) -> String {

View File

@@ -92,10 +92,9 @@ impl HttpClient {
}
};
let mut response =
send_request(&self.inner, &request, config.timeout, send_progress)
.await
.map_err(error_type)?;
let response = send_request(&self.inner, &request, config.timeout, send_progress)
.await
.map_err(error_type)?;
let status_code = response.status();
let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
@@ -116,10 +115,6 @@ impl HttpClient {
}
}
if let Some(preprocessor) = self.response_preprocessor {
preprocessor(&request, &mut response);
}
R::IncomingResponse::try_from_http_response(response)
.map_err(|e| error_type(HttpError::from(e)))
}

View File

@@ -11,6 +11,7 @@ assert_matches2 = { workspace = true }
anyhow = { workspace = true }
assign = "1"
eyeball-im = { workspace = true }
futures = { version = "0.3.29", features = ["executor"] }
futures-core = { workspace = true }
futures-util = { workspace = true }
http = "0.2.11"
@@ -20,7 +21,9 @@ matrix-sdk-test = { workspace = true }
once_cell = { workspace = true }
rand = { workspace = true }
stream_assert = "0.1.1"
reqwest = "0.11.20"
serde_json = "1.0.108"
tempfile = "3.3.0"
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
tracing = { workspace = true }
wiremock = "0.5.21"

View File

@@ -9,7 +9,6 @@ use std::{
use anyhow::Result;
use assign::assign;
use matrix_sdk::{
bytes::Bytes,
config::{RequestConfig, SyncSettings},
encryption::EncryptionSettings,
ruma::api::client::{account::register::v3::Request as RegistrationRequest, uiaa},
@@ -26,7 +25,7 @@ pub struct TestClientBuilder {
username: String,
use_sqlite: bool,
encryption_settings: EncryptionSettings,
response_preprocessor: Option<fn(&http::Request<Bytes>, &mut http::Response<Bytes>)>,
http_proxy: Option<String>,
}
impl TestClientBuilder {
@@ -35,7 +34,7 @@ impl TestClientBuilder {
username: username.into(),
use_sqlite: false,
encryption_settings: Default::default(),
response_preprocessor: None,
http_proxy: None,
}
}
@@ -55,11 +54,8 @@ impl TestClientBuilder {
self
}
pub fn response_preprocessor(
mut self,
r: fn(&http::Request<Bytes>, &mut http::Response<Bytes>),
) -> Self {
self.response_preprocessor = Some(r);
pub fn http_proxy(mut self, proxy: String) -> Self {
self.http_proxy = Some(proxy);
self
}
@@ -83,8 +79,8 @@ impl TestClientBuilder {
.with_encryption_settings(self.encryption_settings)
.request_config(RequestConfig::short_retry());
if let Some(response_preprocessor) = self.response_preprocessor {
client_builder = client_builder.response_preprocessor(response_preprocessor);
if let Some(proxy) = self.http_proxy {
client_builder = client_builder.proxy(proxy);
}
let client = if self.use_sqlite {

View File

@@ -11,16 +11,14 @@ use futures_util::{pin_mut, StreamExt as _};
use matrix_sdk::{
bytes::Bytes,
config::SyncSettings,
reqwest,
ruma::{
api::{
client::{
receipt::create_receipt::v3::ReceiptType,
room::create_room::v3::{Request as CreateRoomRequest, RoomPreset},
sync::sync_events::v4::{
AccountDataConfig, E2EEConfig, ReceiptsConfig, ToDeviceConfig,
},
api::client::{
receipt::create_receipt::v3::ReceiptType,
room::create_room::v3::{Request as CreateRoomRequest, RoomPreset},
sync::sync_events::v4::{
AccountDataConfig, E2EEConfig, ReceiptsConfig, ToDeviceConfig,
},
IncomingResponse,
},
assign,
events::{
@@ -40,7 +38,8 @@ use tokio::{
sync::Mutex,
time::{sleep, timeout},
};
use tracing::{error, warn};
use tracing::{error, info, warn};
use wiremock::{matchers::AnyMatcher, Mock, MockServer};
use crate::helpers::TestClientBuilder;
@@ -547,16 +546,9 @@ async fn test_room_notification_count() -> Result<()> {
Ok(())
}
struct DelayedDecryption {
active: bool,
queued: Vec<serde_json::Value>,
}
static DELAYED_DECRYPTION: StdMutex<DelayedDecryption> =
StdMutex::new(DelayedDecryption { active: true, queued: Vec::new() });
fn delayed_decryption(request: &http::Request<Bytes>, response: &mut http::Response<Bytes>) {
dbg!(request.uri());
let Ok(mut json) = serde_json::from_slice::<serde_json::Value>(response.body()) else {
static DROP_TODEVICE: StdMutex<bool> = StdMutex::new(true);
fn drop_todevice(response: &mut Bytes) {
let Ok(mut json) = serde_json::from_slice::<serde_json::Value>(response) else {
return;
};
let Some(extensions) = json.get_mut("extensions").and_then(|e| e.as_object_mut()) else {
@@ -565,22 +557,80 @@ fn delayed_decryption(request: &http::Request<Bytes>, response: &mut http::Respo
let Some(to_device) = extensions.remove("to_device") else {
return;
};
let d = DELAYED_DECRYPTION.lock().unwrap();
if d.active {
*response.body_mut() = serde_json::to_vec(&json).unwrap().into();
if *DROP_TODEVICE.lock().unwrap() {
info!("Dropping to_device: {to_device}");
*response = serde_json::to_vec(&json).unwrap().into();
return;
} else {
dbg!(to_device);
// d.queued.push(to_device);
}
}
struct CustomResponder {
client: reqwest::Client,
response_preprocessor: fn(&mut Bytes),
}
impl CustomResponder {
fn new(response_preprocessor: fn(&mut Bytes)) -> Self {
Self { client: reqwest::Client::new(), response_preprocessor }
}
}
impl wiremock::Respond for CustomResponder {
fn respond(&self, request: &wiremock::Request) -> wiremock::ResponseTemplate {
let mut req = self.client.request(
request.method.to_string().parse().expect("All methods exist"),
request.url.clone(),
);
for header in &request.headers {
for value in header.1 {
req = req.header(header.0.to_string(), value.to_string());
}
}
req = req.body(request.body.clone());
// Run await inside of non-async fn by spawning a new thread and creating a new
// runtime. Is there a better way?
let response_preprocessor = self.response_preprocessor;
std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let response = timeout(Duration::from_secs(2), req.send()).await;
if let Ok(Ok(response)) = response {
let mut r = wiremock::ResponseTemplate::new(u16::from(response.status()));
for header in response.headers() {
if header.0 == "Content-Length" {
continue;
}
r = r.append_header(header.0.clone(), header.1.clone());
}
let mut bytes = response.bytes().await.unwrap_or_default();
response_preprocessor(&mut bytes);
r = r.set_body_bytes(bytes);
r
} else {
// Gateway timeout
let r = wiremock::ResponseTemplate::new(504);
r
}
})
})
.join()
.expect("Thread panicked")
}
}
#[tokio::test]
async fn test_delayed_decryption_latest_event() -> Result<()> {
let server = MockServer::start().await;
server
.register(Mock::given(AnyMatcher).respond_with(CustomResponder::new(drop_todevice)))
.await;
let alice = TestClientBuilder::new("alice".to_owned())
.randomize_username()
.use_sqlite()
.response_preprocessor(delayed_decryption)
.http_proxy(server.uri())
.build()
.await?;
let bob =
@@ -615,7 +665,7 @@ async fn test_delayed_decryption_latest_event() -> Result<()> {
.await?;
let s = sliding_alice.clone();
tokio::task::spawn(async move {
spawn(async move {
let stream = s.sync();
pin_mut!(stream);
while let Some(up) = stream.next().await {
@@ -624,7 +674,7 @@ async fn test_delayed_decryption_latest_event() -> Result<()> {
});
let s = sliding_bob.clone();
tokio::task::spawn(async move {
spawn(async move {
let stream = s.sync();
pin_mut!(stream);
while let Some(up) = stream.next().await {
@@ -661,14 +711,14 @@ async fn test_delayed_decryption_latest_event() -> Result<()> {
.entries_with_dynamic_adapters(10, alice.roominfo_update_receiver());
entries.set_filter(Box::new(new_filter_all(vec![])));
pin_mut!(stream);
dbg!(timeout(Duration::from_millis(100), stream.next()).await.unwrap());
timeout(Duration::from_millis(100), stream.next()).await.unwrap();
assert!(timeout(Duration::from_millis(100), stream.next()).await.is_err());
assert!(matches!(alice_room.latest_event(), None));
DELAYED_DECRYPTION.lock().unwrap().active = false;
*DROP_TODEVICE.lock().unwrap() = false;
sleep(Duration::from_secs(1)).await;
dbg!(&alice_room.latest_event().unwrap().event().event);
dbg!(timeout(Duration::from_millis(100), stream.next()).await.unwrap());
alice_room.latest_event().unwrap();
timeout(Duration::from_millis(100), stream.next()).await.unwrap();
assert!(timeout(Duration::from_millis(100), stream.next()).await.is_err());
Ok(())