refactor(sdk): Use Cache instead of TtlCache for OAuth 2.0 server metadata

We also cache the data every time that `OAuth::server_metadata()` is
called, in the hope of avoiding some requests with
`OAuth::cached_server_metadata()`.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2026-04-22 13:38:08 +02:00
committed by Ivan Enderlin
parent 6777907f28
commit 53ffa347e4
4 changed files with 74 additions and 23 deletions

View File

@@ -174,7 +174,7 @@ use error::{
};
#[cfg(feature = "e2e-encryption")]
use matrix_sdk_base::crypto::types::qr_login::QrCodeData;
use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings};
use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings, ttl_cache::TtlValue};
#[cfg(feature = "e2e-encryption")]
use matrix_sdk_common::cross_process_lock::CrossProcessLockConfig;
use oauth2::{
@@ -225,7 +225,9 @@ use self::{
};
use super::{AuthData, SessionTokens};
use crate::{
Client, HttpError, RefreshTokenError, Result, client::SessionChange, executor::spawn,
Client, HttpError, RefreshTokenError, Result,
client::{SessionChange, caches::CachedValue},
executor::spawn,
utils::UrlOrQuery,
};
@@ -434,6 +436,9 @@ impl OAuth {
/// [`OAuth::server_metadata()`], and cache the response before returning
/// it.
///
/// The cache can be forced to be refreshed by calling
/// [`OAuth::server_metadata()`] instead.
///
/// In most cases during the authentication process, it is better to always
/// fetch the metadata from the server. This is provided for convenience for
/// cases where the client doesn't want to incur the extra time necessary to
@@ -444,19 +449,45 @@ impl OAuth {
pub async fn cached_server_metadata(
&self,
) -> Result<AuthorizationServerMetadata, OAuthDiscoveryError> {
const CACHE_KEY: &str = "SERVER_METADATA";
let server_metadata_cache = &self.client.inner.caches.server_metadata;
let mut cache = self.client.inner.caches.server_metadata.lock().await;
if let CachedValue::Cached(metadata) = server_metadata_cache.value() {
if metadata.has_expired() {
debug!("spawning task to refresh OAuth 2.0 server metadata cache");
let metadata = if let Some(metadata) = cache.get(CACHE_KEY) {
metadata
} else {
let server_metadata = self.server_metadata().await?;
cache.insert(CACHE_KEY.to_owned(), server_metadata.clone());
server_metadata
let oauth = self.clone();
self.client.task_monitor().spawn_finite_task(
"refresh OAuth 2.0 server metadata cache",
async move {
if let Err(error) = oauth.server_metadata().await {
warn!("failed to refresh OAuth 2.0 server metadata cache: {error}");
}
},
);
}
return Ok(metadata.into_data());
}
let _server_metadata_guard = match server_metadata_cache.refresh_lock.try_lock() {
Ok(guard) => guard,
Err(_) => {
// There is already a refresh in progress, wait for it to finish.
let guard = server_metadata_cache.refresh_lock.lock().await;
// Reuse the data if it was cached and it hasn't expired.
if let CachedValue::Cached(value) = server_metadata_cache.value()
&& !value.has_expired()
{
return Ok(value.into_data());
}
// The data wasn't cached or has expired, we need to make another request.
guard
}
};
Ok(metadata)
self.server_metadata().await
}
/// Fetch the OAuth 2.0 authorization server metadata of the homeserver.
@@ -497,6 +528,10 @@ impl OAuth {
metadata.validate_urls()?;
}
// Always refresh the cache when we got a new value to possibly avoid extra
// calls in cached_server_metadata.
self.client.inner.caches.server_metadata.set_value(TtlValue::new(metadata.clone()));
Ok(metadata)
}

View File

@@ -1,6 +1,8 @@
use std::time::Duration;
use anyhow::Context as _;
use assert_matches::assert_matches;
use matrix_sdk_base::store::RoomLoadSettings;
use matrix_sdk_base::{sleep::sleep, store::RoomLoadSettings, ttl_cache::TtlValue};
use matrix_sdk_test::async_test;
use oauth2::{ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope};
use ruma::{
@@ -20,6 +22,7 @@ use crate::{
AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError,
error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError},
},
client::caches::CachedValue,
test_utils::{
client::{
MockClientBuilder, mock_prev_session_tokens_with_refresh,
@@ -678,21 +681,36 @@ async fn test_server_metadata_cache() {
let server = MatrixMockServer::new().await;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1).mount().await;
oauth_server.mock_server_metadata().ok().expect(2).mount().await;
let client = server.client_builder().logged_in_with_oauth().build().await;
let oauth = client.oauth();
// The cache should not contain the entry.
assert!(!client.inner.caches.server_metadata.lock().await.contains("SERVER_METADATA"));
assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::NotSet);
oauth.cached_server_metadata().await.expect("We should be able to fetch the server metadata");
// Check that the server metadata has been inserted into the cache.
assert!(client.inner.caches.server_metadata.lock().await.contains("SERVER_METADATA"));
assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::Cached(_));
// Another call doesn't make another request for the metadata.
oauth.cached_server_metadata().await.expect("We should be able to fetch the server_metadata");
let metadata = oauth
.cached_server_metadata()
.await
.expect("We should be able to fetch the server_metadata");
// Force an expiry of the cached data.
let mut ttl_value = TtlValue::new(metadata);
ttl_value.expire();
client.inner.caches.server_metadata.set_value(ttl_value);
// Call the method to trigger a cache refresh background task.
oauth.cached_server_metadata().await.expect("We should be able to fetch the server metadata");
// We wait for the task to finish, the endpoint should have been called again.
sleep(Duration::from_secs(1)).await;
assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::Cached(value) if !value.has_expired());
}
#[async_test]

View File

@@ -14,7 +14,7 @@
use std::sync::Arc;
use matrix_sdk_base::{store::WellKnownResponse, ttl_cache::TtlCache};
use matrix_sdk_base::store::WellKnownResponse;
use matrix_sdk_common::{locks::Mutex, ttl_cache::TtlValue};
use ruma::api::{
SupportedVersions,
@@ -39,7 +39,8 @@ pub(crate) struct ClientCaches {
pub(crate) supported_versions: Cache<SupportedVersions>,
/// Well-known information.
pub(super) well_known: Cache<Option<WellKnownResponse>>,
pub(crate) server_metadata: AsyncMutex<TtlCache<String, AuthorizationServerMetadata>>,
/// OAuth 2.0 server metadata.
pub(crate) server_metadata: Cache<AuthorizationServerMetadata>,
/// Homeserver capabilities.
pub(crate) homeserver_capabilities: Cache<Capabilities>,
}

View File

@@ -40,10 +40,7 @@ use matrix_sdk_base::{
sync::{Notification, RoomUpdates},
task_monitor::TaskMonitor,
};
use matrix_sdk_common::{
cross_process_lock::CrossProcessLockConfig,
ttl_cache::{TtlCache, TtlValue},
};
use matrix_sdk_common::{cross_process_lock::CrossProcessLockConfig, ttl_cache::TtlValue};
#[cfg(feature = "e2e-encryption")]
use ruma::events::{InitialStateEvent, room::encryption::RoomEncryptionEventContent};
use ruma::{
@@ -427,7 +424,7 @@ impl ClientInner {
let caches = ClientCaches {
supported_versions: Cache::with_value(supported_versions),
well_known: Cache::with_value(well_known),
server_metadata: Mutex::new(TtlCache::new()),
server_metadata: Cache::new(),
homeserver_capabilities: Cache::new(),
};