mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-04-29 11:35:24 -04:00
refactor(sdk): Use Cache instead of TtlCache for OAuth 2.0 server metadata
We also cache the data every time that `OAuth::server_metadata()` is called, in the hope of avoiding some requests with `OAuth::cached_server_metadata()`. Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
committed by
Ivan Enderlin
parent
6777907f28
commit
53ffa347e4
@@ -174,7 +174,7 @@ use error::{
|
||||
};
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
use matrix_sdk_base::crypto::types::qr_login::QrCodeData;
|
||||
use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings};
|
||||
use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings, ttl_cache::TtlValue};
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
use matrix_sdk_common::cross_process_lock::CrossProcessLockConfig;
|
||||
use oauth2::{
|
||||
@@ -225,7 +225,9 @@ use self::{
|
||||
};
|
||||
use super::{AuthData, SessionTokens};
|
||||
use crate::{
|
||||
Client, HttpError, RefreshTokenError, Result, client::SessionChange, executor::spawn,
|
||||
Client, HttpError, RefreshTokenError, Result,
|
||||
client::{SessionChange, caches::CachedValue},
|
||||
executor::spawn,
|
||||
utils::UrlOrQuery,
|
||||
};
|
||||
|
||||
@@ -434,6 +436,9 @@ impl OAuth {
|
||||
/// [`OAuth::server_metadata()`], and cache the response before returning
|
||||
/// it.
|
||||
///
|
||||
/// The cache can be forced to be refreshed by calling
|
||||
/// [`OAuth::server_metadata()`] instead.
|
||||
///
|
||||
/// In most cases during the authentication process, it is better to always
|
||||
/// fetch the metadata from the server. This is provided for convenience for
|
||||
/// cases where the client doesn't want to incur the extra time necessary to
|
||||
@@ -444,19 +449,45 @@ impl OAuth {
|
||||
pub async fn cached_server_metadata(
|
||||
&self,
|
||||
) -> Result<AuthorizationServerMetadata, OAuthDiscoveryError> {
|
||||
const CACHE_KEY: &str = "SERVER_METADATA";
|
||||
let server_metadata_cache = &self.client.inner.caches.server_metadata;
|
||||
|
||||
let mut cache = self.client.inner.caches.server_metadata.lock().await;
|
||||
if let CachedValue::Cached(metadata) = server_metadata_cache.value() {
|
||||
if metadata.has_expired() {
|
||||
debug!("spawning task to refresh OAuth 2.0 server metadata cache");
|
||||
|
||||
let metadata = if let Some(metadata) = cache.get(CACHE_KEY) {
|
||||
metadata
|
||||
} else {
|
||||
let server_metadata = self.server_metadata().await?;
|
||||
cache.insert(CACHE_KEY.to_owned(), server_metadata.clone());
|
||||
server_metadata
|
||||
let oauth = self.clone();
|
||||
self.client.task_monitor().spawn_finite_task(
|
||||
"refresh OAuth 2.0 server metadata cache",
|
||||
async move {
|
||||
if let Err(error) = oauth.server_metadata().await {
|
||||
warn!("failed to refresh OAuth 2.0 server metadata cache: {error}");
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
return Ok(metadata.into_data());
|
||||
}
|
||||
|
||||
let _server_metadata_guard = match server_metadata_cache.refresh_lock.try_lock() {
|
||||
Ok(guard) => guard,
|
||||
Err(_) => {
|
||||
// There is already a refresh in progress, wait for it to finish.
|
||||
let guard = server_metadata_cache.refresh_lock.lock().await;
|
||||
|
||||
// Reuse the data if it was cached and it hasn't expired.
|
||||
if let CachedValue::Cached(value) = server_metadata_cache.value()
|
||||
&& !value.has_expired()
|
||||
{
|
||||
return Ok(value.into_data());
|
||||
}
|
||||
|
||||
// The data wasn't cached or has expired, we need to make another request.
|
||||
guard
|
||||
}
|
||||
};
|
||||
|
||||
Ok(metadata)
|
||||
self.server_metadata().await
|
||||
}
|
||||
|
||||
/// Fetch the OAuth 2.0 authorization server metadata of the homeserver.
|
||||
@@ -497,6 +528,10 @@ impl OAuth {
|
||||
metadata.validate_urls()?;
|
||||
}
|
||||
|
||||
// Always refresh the cache when we got a new value to possibly avoid extra
|
||||
// calls in cached_server_metadata.
|
||||
self.client.inner.caches.server_metadata.set_value(TtlValue::new(metadata.clone()));
|
||||
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use assert_matches::assert_matches;
|
||||
use matrix_sdk_base::store::RoomLoadSettings;
|
||||
use matrix_sdk_base::{sleep::sleep, store::RoomLoadSettings, ttl_cache::TtlValue};
|
||||
use matrix_sdk_test::async_test;
|
||||
use oauth2::{ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope};
|
||||
use ruma::{
|
||||
@@ -20,6 +22,7 @@ use crate::{
|
||||
AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError,
|
||||
error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError},
|
||||
},
|
||||
client::caches::CachedValue,
|
||||
test_utils::{
|
||||
client::{
|
||||
MockClientBuilder, mock_prev_session_tokens_with_refresh,
|
||||
@@ -678,21 +681,36 @@ async fn test_server_metadata_cache() {
|
||||
let server = MatrixMockServer::new().await;
|
||||
|
||||
let oauth_server = server.oauth();
|
||||
oauth_server.mock_server_metadata().ok().expect(1).mount().await;
|
||||
oauth_server.mock_server_metadata().ok().expect(2).mount().await;
|
||||
|
||||
let client = server.client_builder().logged_in_with_oauth().build().await;
|
||||
let oauth = client.oauth();
|
||||
|
||||
// The cache should not contain the entry.
|
||||
assert!(!client.inner.caches.server_metadata.lock().await.contains("SERVER_METADATA"));
|
||||
assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::NotSet);
|
||||
|
||||
oauth.cached_server_metadata().await.expect("We should be able to fetch the server metadata");
|
||||
|
||||
// Check that the server metadata has been inserted into the cache.
|
||||
assert!(client.inner.caches.server_metadata.lock().await.contains("SERVER_METADATA"));
|
||||
assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::Cached(_));
|
||||
|
||||
// Another call doesn't make another request for the metadata.
|
||||
oauth.cached_server_metadata().await.expect("We should be able to fetch the server_metadata");
|
||||
let metadata = oauth
|
||||
.cached_server_metadata()
|
||||
.await
|
||||
.expect("We should be able to fetch the server_metadata");
|
||||
|
||||
// Force an expiry of the cached data.
|
||||
let mut ttl_value = TtlValue::new(metadata);
|
||||
ttl_value.expire();
|
||||
client.inner.caches.server_metadata.set_value(ttl_value);
|
||||
|
||||
// Call the method to trigger a cache refresh background task.
|
||||
oauth.cached_server_metadata().await.expect("We should be able to fetch the server metadata");
|
||||
|
||||
// We wait for the task to finish, the endpoint should have been called again.
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
assert_matches!(client.inner.caches.server_metadata.value(), CachedValue::Cached(value) if !value.has_expired());
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk_base::{store::WellKnownResponse, ttl_cache::TtlCache};
|
||||
use matrix_sdk_base::store::WellKnownResponse;
|
||||
use matrix_sdk_common::{locks::Mutex, ttl_cache::TtlValue};
|
||||
use ruma::api::{
|
||||
SupportedVersions,
|
||||
@@ -39,7 +39,8 @@ pub(crate) struct ClientCaches {
|
||||
pub(crate) supported_versions: Cache<SupportedVersions>,
|
||||
/// Well-known information.
|
||||
pub(super) well_known: Cache<Option<WellKnownResponse>>,
|
||||
pub(crate) server_metadata: AsyncMutex<TtlCache<String, AuthorizationServerMetadata>>,
|
||||
/// OAuth 2.0 server metadata.
|
||||
pub(crate) server_metadata: Cache<AuthorizationServerMetadata>,
|
||||
/// Homeserver capabilities.
|
||||
pub(crate) homeserver_capabilities: Cache<Capabilities>,
|
||||
}
|
||||
|
||||
@@ -40,10 +40,7 @@ use matrix_sdk_base::{
|
||||
sync::{Notification, RoomUpdates},
|
||||
task_monitor::TaskMonitor,
|
||||
};
|
||||
use matrix_sdk_common::{
|
||||
cross_process_lock::CrossProcessLockConfig,
|
||||
ttl_cache::{TtlCache, TtlValue},
|
||||
};
|
||||
use matrix_sdk_common::{cross_process_lock::CrossProcessLockConfig, ttl_cache::TtlValue};
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
use ruma::events::{InitialStateEvent, room::encryption::RoomEncryptionEventContent};
|
||||
use ruma::{
|
||||
@@ -427,7 +424,7 @@ impl ClientInner {
|
||||
let caches = ClientCaches {
|
||||
supported_versions: Cache::with_value(supported_versions),
|
||||
well_known: Cache::with_value(well_known),
|
||||
server_metadata: Mutex::new(TtlCache::new()),
|
||||
server_metadata: Cache::new(),
|
||||
homeserver_capabilities: Cache::new(),
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user