diff --git a/crates/matrix-sdk/src/authentication/oauth/mod.rs b/crates/matrix-sdk/src/authentication/oauth/mod.rs index 0bfe8d007..d552dc369 100644 --- a/crates/matrix-sdk/src/authentication/oauth/mod.rs +++ b/crates/matrix-sdk/src/authentication/oauth/mod.rs @@ -174,7 +174,7 @@ use error::{ }; #[cfg(feature = "e2e-encryption")] use matrix_sdk_base::crypto::types::qr_login::QrCodeData; -use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings}; +use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings, ttl_cache::TtlValue}; #[cfg(feature = "e2e-encryption")] use matrix_sdk_common::cross_process_lock::CrossProcessLockConfig; use oauth2::{ @@ -225,7 +225,9 @@ use self::{ }; use super::{AuthData, SessionTokens}; use crate::{ - Client, HttpError, RefreshTokenError, Result, client::SessionChange, executor::spawn, + Client, HttpError, RefreshTokenError, Result, + client::{SessionChange, caches::CachedValue}, + executor::spawn, utils::UrlOrQuery, }; @@ -434,6 +436,9 @@ impl OAuth { /// [`OAuth::server_metadata()`], and cache the response before returning /// it. /// + /// The cache can be forced to be refreshed by calling + /// [`OAuth::server_metadata()`] instead. + /// /// In most cases during the authentication process, it is better to always /// fetch the metadata from the server. This is provided for convenience for /// cases where the client doesn't want to incur the extra time necessary to @@ -444,19 +449,45 @@ impl OAuth { pub async fn cached_server_metadata( &self, ) -> Result { - const CACHE_KEY: &str = "SERVER_METADATA"; + let server_metadata_cache = &self.client.inner.caches.server_metadata; - let mut cache = self.client.inner.caches.server_metadata.lock().await; + if let CachedValue::Cached(metadata) = server_metadata_cache.value() { + if metadata.has_expired() { + debug!("spawning task to refresh OAuth 2.0 server metadata cache"); - let metadata = if let Some(metadata) = cache.get(CACHE_KEY) { - metadata - } else { - let server_metadata = self.server_metadata().await?; - cache.insert(CACHE_KEY.to_owned(), server_metadata.clone()); - server_metadata + let oauth = self.clone(); + self.client.task_monitor().spawn_finite_task( + "refresh OAuth 2.0 server metadata cache", + async move { + if let Err(error) = oauth.server_metadata().await { + warn!("failed to refresh OAuth 2.0 server metadata cache: {error}"); + } + }, + ); + } + + return Ok(metadata.into_data()); + } + + let _server_metadata_guard = match server_metadata_cache.refresh_lock.try_lock() { + Ok(guard) => guard, + Err(_) => { + // There is already a refresh in progress, wait for it to finish. + let guard = server_metadata_cache.refresh_lock.lock().await; + + // Reuse the data if it was cached and it hasn't expired. + if let CachedValue::Cached(value) = server_metadata_cache.value() + && !value.has_expired() + { + return Ok(value.into_data()); + } + + // The data wasn't cached or has expired, we need to make another request. + guard + } }; - Ok(metadata) + self.server_metadata().await } /// Fetch the OAuth 2.0 authorization server metadata of the homeserver. @@ -497,6 +528,10 @@ impl OAuth { metadata.validate_urls()?; } + // Always refresh the cache when we got a new value to possibly avoid extra + // calls in cached_server_metadata. + self.client.inner.caches.server_metadata.set_value(TtlValue::new(metadata.clone())); + Ok(metadata) } diff --git a/crates/matrix-sdk/src/authentication/oauth/tests.rs b/crates/matrix-sdk/src/authentication/oauth/tests.rs index 574117b72..ba284f264 100644 --- a/crates/matrix-sdk/src/authentication/oauth/tests.rs +++ b/crates/matrix-sdk/src/authentication/oauth/tests.rs @@ -1,6 +1,8 @@ +use std::time::Duration; + use anyhow::Context as _; use assert_matches::assert_matches; -use matrix_sdk_base::store::RoomLoadSettings; +use matrix_sdk_base::{sleep::sleep, store::RoomLoadSettings, ttl_cache::TtlValue}; use matrix_sdk_test::async_test; use oauth2::{ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope}; use ruma::{ @@ -20,6 +22,7 @@ use crate::{ AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError, error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError}, }, + client::caches::CachedValue, test_utils::{ client::{ MockClientBuilder, mock_prev_session_tokens_with_refresh, @@ -678,21 +681,36 @@ async fn test_server_metadata_cache() { let server = MatrixMockServer::new().await; let oauth_server = server.oauth(); - oauth_server.mock_server_metadata().ok().expect(1).mount().await; + oauth_server.mock_server_metadata().ok().expect(2).mount().await; let client = server.client_builder().logged_in_with_oauth().build().await; let oauth = client.oauth(); // The cache should not contain the entry. - assert!(!client.inner.caches.server_metadata.lock().await.contains("SERVER_METADATA")); + assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::NotSet); oauth.cached_server_metadata().await.expect("We should be able to fetch the server metadata"); // Check that the server metadata has been inserted into the cache. - assert!(client.inner.caches.server_metadata.lock().await.contains("SERVER_METADATA")); + assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::Cached(_)); // Another call doesn't make another request for the metadata. - oauth.cached_server_metadata().await.expect("We should be able to fetch the server_metadata"); + let metadata = oauth + .cached_server_metadata() + .await + .expect("We should be able to fetch the server_metadata"); + + // Force an expiry of the cached data. + let mut ttl_value = TtlValue::new(metadata); + ttl_value.expire(); + client.inner.caches.server_metadata.set_value(ttl_value); + + // Call the method to trigger a cache refresh background task. + oauth.cached_server_metadata().await.expect("We should be able to fetch the server metadata"); + + // We wait for the task to finish, the endpoint should have been called again. + sleep(Duration::from_secs(1)).await; + assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::Cached(value) if !value.has_expired()); } #[async_test] diff --git a/crates/matrix-sdk/src/client/caches.rs b/crates/matrix-sdk/src/client/caches.rs index 431d13769..6d38deb1f 100644 --- a/crates/matrix-sdk/src/client/caches.rs +++ b/crates/matrix-sdk/src/client/caches.rs @@ -14,7 +14,7 @@ use std::sync::Arc; -use matrix_sdk_base::{store::WellKnownResponse, ttl_cache::TtlCache}; +use matrix_sdk_base::store::WellKnownResponse; use matrix_sdk_common::{locks::Mutex, ttl_cache::TtlValue}; use ruma::api::{ SupportedVersions, @@ -39,7 +39,8 @@ pub(crate) struct ClientCaches { pub(crate) supported_versions: Cache, /// Well-known information. pub(super) well_known: Cache>, - pub(crate) server_metadata: AsyncMutex>, + /// OAuth 2.0 server metadata. + pub(crate) server_metadata: Cache, /// Homeserver capabilities. pub(crate) homeserver_capabilities: Cache, } diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index bcbcf607f..86ad9fcf9 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -40,10 +40,7 @@ use matrix_sdk_base::{ sync::{Notification, RoomUpdates}, task_monitor::TaskMonitor, }; -use matrix_sdk_common::{ - cross_process_lock::CrossProcessLockConfig, - ttl_cache::{TtlCache, TtlValue}, -}; +use matrix_sdk_common::{cross_process_lock::CrossProcessLockConfig, ttl_cache::TtlValue}; #[cfg(feature = "e2e-encryption")] use ruma::events::{InitialStateEvent, room::encryption::RoomEncryptionEventContent}; use ruma::{ @@ -427,7 +424,7 @@ impl ClientInner { let caches = ClientCaches { supported_versions: Cache::with_value(supported_versions), well_known: Cache::with_value(well_known), - server_metadata: Mutex::new(TtlCache::new()), + server_metadata: Cache::new(), homeserver_capabilities: Cache::new(), };