From 48fbda844fbccbbded9fb3a4fec099ffd5ac6129 Mon Sep 17 00:00:00 2001 From: Benjamin Bouvier Date: Thu, 21 Nov 2024 11:03:54 +0100 Subject: [PATCH] fix(oidc): make sure we keep track of an ongoing OIDC refresh up to the end There's a lock making sure we're not doing multiple refreshes of an OIDC token at the same time. Unfortunately, this lock could be dropped, if the task spawned by the inner function was detached. The lock must be held throughout the entire detached task's lifetime, which this refactoring ensures, by setting the lock's result after calling the inner function. --- crates/matrix-sdk/src/authentication/mod.rs | 4 +- crates/matrix-sdk/src/client/builder/mod.rs | 2 +- crates/matrix-sdk/src/oidc/mod.rs | 190 ++++++++++---------- 3 files changed, 103 insertions(+), 93 deletions(-) diff --git a/crates/matrix-sdk/src/authentication/mod.rs b/crates/matrix-sdk/src/authentication/mod.rs index e7d06352d..81df8af3b 100644 --- a/crates/matrix-sdk/src/authentication/mod.rs +++ b/crates/matrix-sdk/src/authentication/mod.rs @@ -17,6 +17,8 @@ // TODO:(pixlwave) Move AuthenticationService from the FFI into this module. // TODO:(poljar) Move the oidc and matrix_auth modules under this module. +use std::sync::Arc; + use as_variant::as_variant; use matrix_sdk_base::SessionMeta; use tokio::sync::{broadcast, Mutex, OnceCell}; @@ -58,7 +60,7 @@ pub(crate) struct AuthCtx { pub(crate) handle_refresh_tokens: bool, /// Lock making sure we're only doing one token refresh at a time. - pub(crate) refresh_token_lock: Mutex>, + pub(crate) refresh_token_lock: Arc>>, /// Session change publisher. Allows the subscriber to handle changes to the /// session such as logging out when the access token is invalid or diff --git a/crates/matrix-sdk/src/client/builder/mod.rs b/crates/matrix-sdk/src/client/builder/mod.rs index f64ca2777..86b189a1f 100644 --- a/crates/matrix-sdk/src/client/builder/mod.rs +++ b/crates/matrix-sdk/src/client/builder/mod.rs @@ -521,7 +521,7 @@ impl ClientBuilder { let auth_ctx = Arc::new(AuthCtx { handle_refresh_tokens: self.handle_refresh_tokens, - refresh_token_lock: Mutex::new(Ok(())), + refresh_token_lock: Arc::new(Mutex::new(Ok(()))), session_change_sender: broadcast::Sender::new(1), auth_data: OnceCell::default(), reload_session_callback: OnceCell::default(), diff --git a/crates/matrix-sdk/src/oidc/mod.rs b/crates/matrix-sdk/src/oidc/mod.rs index ab686b23a..aba26c6f7 100644 --- a/crates/matrix-sdk/src/oidc/mod.rs +++ b/crates/matrix-sdk/src/oidc/mod.rs @@ -1292,88 +1292,64 @@ impl Oidc { } async fn refresh_access_token_inner( - &self, + self, refresh_token: String, + provider_metadata: VerifiedProviderMetadata, + credentials: ClientCredentials, + client_metadata: VerifiedClientMetadata, latest_id_token: Option>, - lock: Option, + cross_process_lock: Option, ) -> Result<(), OidcError> { - // Do not interrupt refresh access token requests and processing, by detaching - // the request sending and response processing. + trace!( + "Token refresh: attempting to refresh with refresh_token {:x}", + hash_str(&refresh_token) + ); - let provider_metadata = self.provider_metadata().await?; + let new_tokens = self + .backend + .refresh_access_token( + provider_metadata, + credentials, + &client_metadata, + refresh_token.clone(), + latest_id_token.clone(), + ) + .await + .map_err(OidcError::from)?; - let this = self.clone(); - let data = self.data().ok_or(OidcError::NotAuthenticated)?; - let credentials = data.credentials.clone(); - let metadata = data.metadata.clone(); + trace!( + "Token refresh: new refresh_token: {} / access_token: {:x}", + new_tokens + .refresh_token + .as_deref() + .map(|token| format!("{:x}", hash_str(token))) + .unwrap_or_else(|| "".to_owned()), + hash_str(&new_tokens.access_token) + ); - spawn(async move { - trace!( - "Token refresh: attempting to refresh with refresh_token {:x}", - hash_str(&refresh_token) - ); + let tokens = OidcSessionTokens { + access_token: new_tokens.access_token, + refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)), + latest_id_token, + }; - match this - .backend - .refresh_access_token( - provider_metadata, - credentials, - &metadata, - refresh_token.clone(), - latest_id_token.clone(), - ) - .await - .map_err(OidcError::from) - { - Ok(new_tokens) => { - trace!( - "Token refresh: new refresh_token: {} / access_token: {:x}", - new_tokens - .refresh_token - .as_deref() - .map(|token| format!("{:x}", hash_str(token))) - .unwrap_or_else(|| "".to_owned()), - hash_str(&new_tokens.access_token) - ); + self.set_session_tokens(tokens.clone()); - let tokens = OidcSessionTokens { - access_token: new_tokens.access_token, - refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)), - latest_id_token, - }; - - this.set_session_tokens(tokens.clone()); - - // Call the save_session_callback if set, while the optional lock is being held. - if let Some(save_session_callback) = - this.client.inner.auth_ctx.save_session_callback.get() - { - // Satisfies the save_session_callback invariant: set_session_tokens has - // been called just above. - if let Err(err) = save_session_callback(this.client.clone()) { - error!("when saving session after refresh: {err}"); - } - } - - if let Some(mut lock) = lock { - lock.save_in_memory_and_db(&tokens).await?; - } - - _ = this - .client - .inner - .auth_ctx - .session_change_sender - .send(SessionChange::TokensRefreshed); - - Ok(()) - } - - Err(err) => Err(err), + // Call the save_session_callback if set, while the optional lock is being held. + if let Some(save_session_callback) = self.client.inner.auth_ctx.save_session_callback.get() + { + // Satisfies the save_session_callback invariant: set_session_tokens has + // been called just above. + if let Err(err) = save_session_callback(self.client.clone()) { + error!("when saving session after refresh: {err}"); } - }) - .await - .expect("joining")?; + } + + if let Some(mut lock) = cross_process_lock { + lock.save_in_memory_and_db(&tokens).await?; + } + + _ = self.client.inner.auth_ctx.session_change_sender.send(SessionChange::TokensRefreshed); Ok(()) } @@ -1393,10 +1369,6 @@ impl Oidc { /// /// [`ClientBuilder::handle_refresh_tokens()`]: crate::ClientBuilder::handle_refresh_tokens() pub async fn refresh_access_token(&self) -> Result<(), RefreshTokenError> { - let client = &self.client; - - let refresh_status_lock = client.inner.auth_ctx.refresh_token_lock.try_lock(); - macro_rules! fail { ($lock:expr, $err:expr) => { let error = $err; @@ -1405,6 +1377,10 @@ impl Oidc { }; } + let client = &self.client; + + let refresh_status_lock = client.inner.auth_ctx.refresh_token_lock.clone().try_lock_owned(); + let Ok(mut refresh_status_guard) = refresh_status_lock else { // There's already a request to refresh happening in the same process. Wait for // it to finish. @@ -1446,22 +1422,54 @@ impl Oidc { fail!(refresh_status_guard, RefreshTokenError::RefreshTokenRequired); }; - match self - .refresh_access_token_inner( - refresh_token, - session_tokens.latest_id_token, - cross_process_guard, - ) - .await - { - Ok(()) => { - *refresh_status_guard = Ok(()); - Ok(()) + let provider_metadata = match self.provider_metadata().await { + Ok(metadata) => metadata, + Err(err) => { + let err = Arc::new(err); + fail!(refresh_status_guard, RefreshTokenError::Oidc(err)); } - Err(error) => { - fail!(refresh_status_guard, RefreshTokenError::Oidc(error.into())); + }; + + let Some(auth_data) = self.data() else { + fail!( + refresh_status_guard, + RefreshTokenError::Oidc(Arc::new(OidcError::NotAuthenticated)) + ); + }; + + let credentials = auth_data.credentials.clone(); + let client_metadata = auth_data.metadata.clone(); + + // Do not interrupt refresh access token requests and processing, by detaching + // the request sending and response processing. + // Make sure to keep the `refresh_status_guard` during the entire processing. + + let this = self.clone(); + + spawn(async move { + match this + .refresh_access_token_inner( + refresh_token, + provider_metadata, + credentials, + client_metadata, + session_tokens.latest_id_token, + cross_process_guard, + ) + .await + { + Ok(()) => { + *refresh_status_guard = Ok(()); + Ok(()) + } + Err(err) => { + let err = RefreshTokenError::Oidc(Arc::new(err)); + fail!(refresh_status_guard, err); + } } - } + }) + .await + .expect("joining") } /// Log out from the currently authenticated session.