fix(oidc): make sure we keep track of an ongoing OIDC refresh up to the end

There's a lock making sure we're not doing multiple refreshes of an OIDC
token at the same time. Unfortunately, this lock could be dropped, if
the task spawned by the inner function was detached.

The lock must be held throughout the entire detached task's lifetime,
which this refactoring ensures, by setting the lock's result after
calling the inner function.
This commit is contained in:
Benjamin Bouvier
2024-11-21 11:03:54 +01:00
parent bc70f3c051
commit 48fbda844f
3 changed files with 103 additions and 93 deletions

View File

@@ -17,6 +17,8 @@
// TODO:(pixlwave) Move AuthenticationService from the FFI into this module.
// TODO:(poljar) Move the oidc and matrix_auth modules under this module.
use std::sync::Arc;
use as_variant::as_variant;
use matrix_sdk_base::SessionMeta;
use tokio::sync::{broadcast, Mutex, OnceCell};
@@ -58,7 +60,7 @@ pub(crate) struct AuthCtx {
pub(crate) handle_refresh_tokens: bool,
/// Lock making sure we're only doing one token refresh at a time.
pub(crate) refresh_token_lock: Mutex<Result<(), RefreshTokenError>>,
pub(crate) refresh_token_lock: Arc<Mutex<Result<(), RefreshTokenError>>>,
/// Session change publisher. Allows the subscriber to handle changes to the
/// session such as logging out when the access token is invalid or

View File

@@ -521,7 +521,7 @@ impl ClientBuilder {
let auth_ctx = Arc::new(AuthCtx {
handle_refresh_tokens: self.handle_refresh_tokens,
refresh_token_lock: Mutex::new(Ok(())),
refresh_token_lock: Arc::new(Mutex::new(Ok(()))),
session_change_sender: broadcast::Sender::new(1),
auth_data: OnceCell::default(),
reload_session_callback: OnceCell::default(),

View File

@@ -1292,88 +1292,64 @@ impl Oidc {
}
async fn refresh_access_token_inner(
&self,
self,
refresh_token: String,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
client_metadata: VerifiedClientMetadata,
latest_id_token: Option<IdToken<'static>>,
lock: Option<CrossProcessRefreshLockGuard>,
cross_process_lock: Option<CrossProcessRefreshLockGuard>,
) -> Result<(), OidcError> {
// Do not interrupt refresh access token requests and processing, by detaching
// the request sending and response processing.
trace!(
"Token refresh: attempting to refresh with refresh_token {:x}",
hash_str(&refresh_token)
);
let provider_metadata = self.provider_metadata().await?;
let new_tokens = self
.backend
.refresh_access_token(
provider_metadata,
credentials,
&client_metadata,
refresh_token.clone(),
latest_id_token.clone(),
)
.await
.map_err(OidcError::from)?;
let this = self.clone();
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
let credentials = data.credentials.clone();
let metadata = data.metadata.clone();
trace!(
"Token refresh: new refresh_token: {} / access_token: {:x}",
new_tokens
.refresh_token
.as_deref()
.map(|token| format!("{:x}", hash_str(token)))
.unwrap_or_else(|| "<none>".to_owned()),
hash_str(&new_tokens.access_token)
);
spawn(async move {
trace!(
"Token refresh: attempting to refresh with refresh_token {:x}",
hash_str(&refresh_token)
);
let tokens = OidcSessionTokens {
access_token: new_tokens.access_token,
refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)),
latest_id_token,
};
match this
.backend
.refresh_access_token(
provider_metadata,
credentials,
&metadata,
refresh_token.clone(),
latest_id_token.clone(),
)
.await
.map_err(OidcError::from)
{
Ok(new_tokens) => {
trace!(
"Token refresh: new refresh_token: {} / access_token: {:x}",
new_tokens
.refresh_token
.as_deref()
.map(|token| format!("{:x}", hash_str(token)))
.unwrap_or_else(|| "<none>".to_owned()),
hash_str(&new_tokens.access_token)
);
self.set_session_tokens(tokens.clone());
let tokens = OidcSessionTokens {
access_token: new_tokens.access_token,
refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)),
latest_id_token,
};
this.set_session_tokens(tokens.clone());
// Call the save_session_callback if set, while the optional lock is being held.
if let Some(save_session_callback) =
this.client.inner.auth_ctx.save_session_callback.get()
{
// Satisfies the save_session_callback invariant: set_session_tokens has
// been called just above.
if let Err(err) = save_session_callback(this.client.clone()) {
error!("when saving session after refresh: {err}");
}
}
if let Some(mut lock) = lock {
lock.save_in_memory_and_db(&tokens).await?;
}
_ = this
.client
.inner
.auth_ctx
.session_change_sender
.send(SessionChange::TokensRefreshed);
Ok(())
}
Err(err) => Err(err),
// Call the save_session_callback if set, while the optional lock is being held.
if let Some(save_session_callback) = self.client.inner.auth_ctx.save_session_callback.get()
{
// Satisfies the save_session_callback invariant: set_session_tokens has
// been called just above.
if let Err(err) = save_session_callback(self.client.clone()) {
error!("when saving session after refresh: {err}");
}
})
.await
.expect("joining")?;
}
if let Some(mut lock) = cross_process_lock {
lock.save_in_memory_and_db(&tokens).await?;
}
_ = self.client.inner.auth_ctx.session_change_sender.send(SessionChange::TokensRefreshed);
Ok(())
}
@@ -1393,10 +1369,6 @@ impl Oidc {
///
/// [`ClientBuilder::handle_refresh_tokens()`]: crate::ClientBuilder::handle_refresh_tokens()
pub async fn refresh_access_token(&self) -> Result<(), RefreshTokenError> {
let client = &self.client;
let refresh_status_lock = client.inner.auth_ctx.refresh_token_lock.try_lock();
macro_rules! fail {
($lock:expr, $err:expr) => {
let error = $err;
@@ -1405,6 +1377,10 @@ impl Oidc {
};
}
let client = &self.client;
let refresh_status_lock = client.inner.auth_ctx.refresh_token_lock.clone().try_lock_owned();
let Ok(mut refresh_status_guard) = refresh_status_lock else {
// There's already a request to refresh happening in the same process. Wait for
// it to finish.
@@ -1446,22 +1422,54 @@ impl Oidc {
fail!(refresh_status_guard, RefreshTokenError::RefreshTokenRequired);
};
match self
.refresh_access_token_inner(
refresh_token,
session_tokens.latest_id_token,
cross_process_guard,
)
.await
{
Ok(()) => {
*refresh_status_guard = Ok(());
Ok(())
let provider_metadata = match self.provider_metadata().await {
Ok(metadata) => metadata,
Err(err) => {
let err = Arc::new(err);
fail!(refresh_status_guard, RefreshTokenError::Oidc(err));
}
Err(error) => {
fail!(refresh_status_guard, RefreshTokenError::Oidc(error.into()));
};
let Some(auth_data) = self.data() else {
fail!(
refresh_status_guard,
RefreshTokenError::Oidc(Arc::new(OidcError::NotAuthenticated))
);
};
let credentials = auth_data.credentials.clone();
let client_metadata = auth_data.metadata.clone();
// Do not interrupt refresh access token requests and processing, by detaching
// the request sending and response processing.
// Make sure to keep the `refresh_status_guard` during the entire processing.
let this = self.clone();
spawn(async move {
match this
.refresh_access_token_inner(
refresh_token,
provider_metadata,
credentials,
client_metadata,
session_tokens.latest_id_token,
cross_process_guard,
)
.await
{
Ok(()) => {
*refresh_status_guard = Ok(());
Ok(())
}
Err(err) => {
let err = RefreshTokenError::Oidc(Arc::new(err));
fail!(refresh_status_guard, err);
}
}
}
})
.await
.expect("joining")
}
/// Log out from the currently authenticated session.