From 19fcae4e0b8f58f760b72df425a720b43f682f09 Mon Sep 17 00:00:00 2001 From: Benjamin Bouvier Date: Thu, 4 Jul 2024 19:21:00 +0200 Subject: [PATCH] sdk: add a way to reset the in-memory and on-disk caches for server capabilities This required changing the OnceCell to RwLock>, because a OnceCell can't be reset to another value. --- bindings/matrix-sdk-ffi/src/client.rs | 9 ++ crates/matrix-sdk/src/client/mod.rs | 179 ++++++++++++++++++----- crates/matrix-sdk/src/matrix_auth/mod.rs | 4 +- 3 files changed, 156 insertions(+), 36 deletions(-) diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index 22282718e..1a0384974 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -442,6 +442,15 @@ impl Client { let http_client = self.inner.http_client(); Ok(http_client.get(url).send().await?.text().await?) } + + /// Empty the server version and unstable features cache. + /// + /// Since the SDK caches server capabilities (versions and unstable + /// features), it's possible to have a stale entry in the cache. This + /// functions makes it possible to force reset it. + pub async fn reset_server_capabilities(&self) -> Result<(), ClientError> { + Ok(self.inner.reset_server_capabilities().await?) + } } impl Client { diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 7d4de778b..4788503c8 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -239,10 +239,10 @@ pub(crate) struct ClientInner { base_client: BaseClient, /// The Matrix versions the server supports (well-known ones only) - server_versions: OnceCell>, + server_versions: RwLock>>, /// The unstable features and their on/off state on the server - unstable_features: OnceCell>, + unstable_features: RwLock>>, /// Collection of locks individual client methods might want to use, either /// to ensure that only a single call to a method happens at once or to @@ -324,8 +324,8 @@ impl ClientInner { http_client, base_client, locks: Default::default(), - server_versions: OnceCell::new_with(server_versions), - unstable_features: OnceCell::new_with(unstable_features), + server_versions: RwLock::new(server_versions), + unstable_features: RwLock::new(unstable_features), typing_notice_times: Default::default(), event_handlers: Default::default(), notification_handlers: Default::default(), @@ -1430,7 +1430,7 @@ impl Client { config, homeserver, access_token.as_deref(), - self.server_versions().await?, + &self.server_versions().await?, send_progress, ) .await @@ -1509,20 +1509,25 @@ impl Client { self.fetch_and_cache_server_capabilities().await } - pub(crate) async fn server_versions(&self) -> HttpResult<&[MatrixVersion]> { - Ok(self - .inner - .server_versions - .get_or_try_init(|| async { - let (versions, unstable_features) = - self.load_or_fetch_server_capabilities().await?; + pub(crate) async fn server_versions(&self) -> HttpResult> { + if let Some(server_versions) = self.inner.server_versions.read().await.as_ref() { + return Ok(server_versions.clone()); + } - // Fill the unstable features, while we're at it. - self.inner.unstable_features.get_or_init(|| async { unstable_features }).await; + let mut guard = self.inner.server_versions.write().await; + if let Some(server_versions) = guard.as_ref() { + return Ok(server_versions.clone()); + } - HttpResult::Ok(versions) - }) - .await?) + let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?; + + *guard = Some(versions.clone()); + drop(guard); + + // Fill the unstable features, while we're at it. + *self.inner.unstable_features.write().await = Some(unstable_features); + + Ok(versions) } /// Get unstable features from by fetching from the server or the cache. @@ -1539,20 +1544,39 @@ impl Client { /// let msc_x = unstable_features.get("msc_x").unwrap_or(&false); /// # anyhow::Ok(()) }; /// ``` - pub async fn unstable_features(&self) -> HttpResult<&BTreeMap> { - Ok(self - .inner - .unstable_features - .get_or_try_init(|| async { - let (versions, unstable_features) = - self.load_or_fetch_server_capabilities().await?; + pub async fn unstable_features(&self) -> HttpResult> { + if let Some(unstable_features) = self.inner.unstable_features.read().await.as_ref() { + return Ok(unstable_features.clone()); + } - // Fill the versions, while we're at it. - self.inner.server_versions.get_or_init(|| async { versions }).await; + let mut guard = self.inner.unstable_features.write().await; + if let Some(unstable_features) = guard.as_ref() { + return Ok(unstable_features.clone()); + } - HttpResult::Ok(unstable_features) - }) - .await?) + let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?; + + *guard = Some(unstable_features.clone()); + drop(guard); + + // Fill the versions, while we're at it. + *self.inner.server_versions.write().await = Some(versions); + + Ok(unstable_features) + } + + /// Empty the server version and unstable features cache. + /// + /// Since the SDK caches server capabilities (versions and unstable + /// features), it's possible to have a stale entry in the cache. This + /// functions makes it possible to force reset it. + pub async fn reset_server_capabilities(&self) -> Result<()> { + // Empty the in-memory caches. + *self.inner.server_versions.write().await = None; + *self.inner.unstable_features.write().await = None; + + // Empty the store cache. + Ok(self.store().remove_kv_data(StateStoreDataKey::ServerCapabilities).await?) } /// Check whether MSC 4028 is enabled on the homeserver. @@ -2161,8 +2185,8 @@ impl Client { sliding_sync_proxy, self.inner.http_client.clone(), self.inner.base_client.clone_with_in_memory_state_store(), - self.inner.server_versions.get().cloned(), - self.inner.unstable_features.get().cloned(), + self.inner.server_versions.read().await.clone(), + self.inner.unstable_features.read().await.clone(), self.inner.respect_login_well_known, self.inner.event_cache.clone(), self.inner.send_queue_data.clone(), @@ -2236,7 +2260,10 @@ pub(crate) mod tests { use std::{sync::Arc, time::Duration}; use assert_matches::assert_matches; - use matrix_sdk_base::RoomState; + use matrix_sdk_base::{ + store::{MemoryStore, StoreConfig}, + RoomState, + }; use matrix_sdk_test::{ async_test, test_json, JoinedRoomBuilder, StateTestEvent, SyncResponseBuilder, DEFAULT_TEST_ROOM_ID, @@ -2245,8 +2272,8 @@ pub(crate) mod tests { wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); use ruma::{ - events::ignored_user_list::IgnoredUserListEventContent, owned_room_id, room_id, RoomId, - UserId, + api::MatrixVersion, events::ignored_user_list::IgnoredUserListEventContent, owned_room_id, + room_id, RoomId, ServerName, UserId, }; use url::Url; use wiremock::{ @@ -2609,4 +2636,88 @@ pub(crate) mod tests { Arc::strong_count(&client.unwrap().inner) ); } + + #[async_test] + async fn test_server_capabilities_caching() { + let server = MockServer::start().await; + let server_url = server.uri(); + let domain = server_url.strip_prefix("http://").unwrap(); + let server_name = <&ServerName>::try_from(domain).unwrap(); + + Mock::given(method("GET")) + .and(path("/.well-known/matrix/client")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()), + "application/json", + )) + .named("well known mock") + .expect(2) + .mount(&server) + .await; + + let versions_mock = Mock::given(method("GET")) + .and(path("/_matrix/client/versions")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS)) + .named("first versions mock") + .expect(1) + .mount_as_scoped(&server) + .await; + + let memory_store = Arc::new(MemoryStore::new()); + let client = Client::builder() + .insecure_server_name_no_tls(server_name) + .store_config(StoreConfig::new().state_store(memory_store.clone())) + .build() + .await + .unwrap(); + + assert_eq!(client.server_versions().await.unwrap().len(), 1); + + // This second call hits the in-memory cache. + assert!(client + .server_versions() + .await + .unwrap() + .iter() + .any(|version| *version == MatrixVersion::V1_0)); + + drop(client); + + let client = Client::builder() + .insecure_server_name_no_tls(server_name) + .store_config(StoreConfig::new().state_store(memory_store.clone())) + .build() + .await + .unwrap(); + + // This third call hits the on-disk cache. + assert_eq!( + client.unstable_features().await.unwrap().get("org.matrix.e2e_cross_signing"), + Some(&true) + ); + + drop(versions_mock); + server.verify().await; + + // Now, reset the cache, and observe the endpoint being called again once. + client.reset_server_capabilities().await.unwrap(); + + Mock::given(method("GET")) + .and(path("/_matrix/client/versions")) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS)) + .expect(1) + .named("second versions mock") + .mount(&server) + .await; + + // Hits network again. + assert_eq!(client.server_versions().await.unwrap().len(), 1); + // Hits in-memory cache again. + assert!(client + .server_versions() + .await + .unwrap() + .iter() + .any(|version| *version == MatrixVersion::V1_0)); + } } diff --git a/crates/matrix-sdk/src/matrix_auth/mod.rs b/crates/matrix-sdk/src/matrix_auth/mod.rs index ec2ad6441..901f87a37 100644 --- a/crates/matrix-sdk/src/matrix_auth/mod.rs +++ b/crates/matrix-sdk/src/matrix_auth/mod.rs @@ -130,13 +130,13 @@ impl MatrixAuth { .try_into_http_request::>( homeserver.as_str(), SendAccessToken::None, - server_versions, + &server_versions, ) } else { sso_login::v3::Request::new(redirect_url.to_owned()).try_into_http_request::>( homeserver.as_str(), SendAccessToken::None, - server_versions, + &server_versions, ) };