diff --git a/crates/matrix-sdk/src/authentication/mod.rs b/crates/matrix-sdk/src/authentication/mod.rs index e7d06352d..81df8af3b 100644 --- a/crates/matrix-sdk/src/authentication/mod.rs +++ b/crates/matrix-sdk/src/authentication/mod.rs @@ -17,6 +17,8 @@ // TODO:(pixlwave) Move AuthenticationService from the FFI into this module. // TODO:(poljar) Move the oidc and matrix_auth modules under this module. +use std::sync::Arc; + use as_variant::as_variant; use matrix_sdk_base::SessionMeta; use tokio::sync::{broadcast, Mutex, OnceCell}; @@ -58,7 +60,7 @@ pub(crate) struct AuthCtx { pub(crate) handle_refresh_tokens: bool, /// Lock making sure we're only doing one token refresh at a time. - pub(crate) refresh_token_lock: Mutex>, + pub(crate) refresh_token_lock: Arc>>, /// Session change publisher. Allows the subscriber to handle changes to the /// session such as logging out when the access token is invalid or diff --git a/crates/matrix-sdk/src/client/builder/mod.rs b/crates/matrix-sdk/src/client/builder/mod.rs index f64ca2777..86b189a1f 100644 --- a/crates/matrix-sdk/src/client/builder/mod.rs +++ b/crates/matrix-sdk/src/client/builder/mod.rs @@ -521,7 +521,7 @@ impl ClientBuilder { let auth_ctx = Arc::new(AuthCtx { handle_refresh_tokens: self.handle_refresh_tokens, - refresh_token_lock: Mutex::new(Ok(())), + refresh_token_lock: Arc::new(Mutex::new(Ok(()))), session_change_sender: broadcast::Sender::new(1), auth_data: OnceCell::default(), reload_session_callback: OnceCell::default(), diff --git a/crates/matrix-sdk/src/oidc/mod.rs b/crates/matrix-sdk/src/oidc/mod.rs index ab686b23a..aba26c6f7 100644 --- a/crates/matrix-sdk/src/oidc/mod.rs +++ b/crates/matrix-sdk/src/oidc/mod.rs @@ -1292,88 +1292,64 @@ impl Oidc { } async fn refresh_access_token_inner( - &self, + self, refresh_token: String, + provider_metadata: VerifiedProviderMetadata, + credentials: ClientCredentials, + client_metadata: VerifiedClientMetadata, latest_id_token: Option>, - lock: Option, + cross_process_lock: Option, ) -> Result<(), OidcError> { - // Do not interrupt refresh access token requests and processing, by detaching - // the request sending and response processing. + trace!( + "Token refresh: attempting to refresh with refresh_token {:x}", + hash_str(&refresh_token) + ); - let provider_metadata = self.provider_metadata().await?; + let new_tokens = self + .backend + .refresh_access_token( + provider_metadata, + credentials, + &client_metadata, + refresh_token.clone(), + latest_id_token.clone(), + ) + .await + .map_err(OidcError::from)?; - let this = self.clone(); - let data = self.data().ok_or(OidcError::NotAuthenticated)?; - let credentials = data.credentials.clone(); - let metadata = data.metadata.clone(); + trace!( + "Token refresh: new refresh_token: {} / access_token: {:x}", + new_tokens + .refresh_token + .as_deref() + .map(|token| format!("{:x}", hash_str(token))) + .unwrap_or_else(|| "".to_owned()), + hash_str(&new_tokens.access_token) + ); - spawn(async move { - trace!( - "Token refresh: attempting to refresh with refresh_token {:x}", - hash_str(&refresh_token) - ); + let tokens = OidcSessionTokens { + access_token: new_tokens.access_token, + refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)), + latest_id_token, + }; - match this - .backend - .refresh_access_token( - provider_metadata, - credentials, - &metadata, - refresh_token.clone(), - latest_id_token.clone(), - ) - .await - .map_err(OidcError::from) - { - Ok(new_tokens) => { - trace!( - "Token refresh: new refresh_token: {} / access_token: {:x}", - new_tokens - .refresh_token - .as_deref() - .map(|token| format!("{:x}", hash_str(token))) - .unwrap_or_else(|| "".to_owned()), - hash_str(&new_tokens.access_token) - ); + self.set_session_tokens(tokens.clone()); - let tokens = OidcSessionTokens { - access_token: new_tokens.access_token, - refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)), - latest_id_token, - }; - - this.set_session_tokens(tokens.clone()); - - // Call the save_session_callback if set, while the optional lock is being held. - if let Some(save_session_callback) = - this.client.inner.auth_ctx.save_session_callback.get() - { - // Satisfies the save_session_callback invariant: set_session_tokens has - // been called just above. - if let Err(err) = save_session_callback(this.client.clone()) { - error!("when saving session after refresh: {err}"); - } - } - - if let Some(mut lock) = lock { - lock.save_in_memory_and_db(&tokens).await?; - } - - _ = this - .client - .inner - .auth_ctx - .session_change_sender - .send(SessionChange::TokensRefreshed); - - Ok(()) - } - - Err(err) => Err(err), + // Call the save_session_callback if set, while the optional lock is being held. + if let Some(save_session_callback) = self.client.inner.auth_ctx.save_session_callback.get() + { + // Satisfies the save_session_callback invariant: set_session_tokens has + // been called just above. + if let Err(err) = save_session_callback(self.client.clone()) { + error!("when saving session after refresh: {err}"); } - }) - .await - .expect("joining")?; + } + + if let Some(mut lock) = cross_process_lock { + lock.save_in_memory_and_db(&tokens).await?; + } + + _ = self.client.inner.auth_ctx.session_change_sender.send(SessionChange::TokensRefreshed); Ok(()) } @@ -1393,10 +1369,6 @@ impl Oidc { /// /// [`ClientBuilder::handle_refresh_tokens()`]: crate::ClientBuilder::handle_refresh_tokens() pub async fn refresh_access_token(&self) -> Result<(), RefreshTokenError> { - let client = &self.client; - - let refresh_status_lock = client.inner.auth_ctx.refresh_token_lock.try_lock(); - macro_rules! fail { ($lock:expr, $err:expr) => { let error = $err; @@ -1405,6 +1377,10 @@ impl Oidc { }; } + let client = &self.client; + + let refresh_status_lock = client.inner.auth_ctx.refresh_token_lock.clone().try_lock_owned(); + let Ok(mut refresh_status_guard) = refresh_status_lock else { // There's already a request to refresh happening in the same process. Wait for // it to finish. @@ -1446,22 +1422,54 @@ impl Oidc { fail!(refresh_status_guard, RefreshTokenError::RefreshTokenRequired); }; - match self - .refresh_access_token_inner( - refresh_token, - session_tokens.latest_id_token, - cross_process_guard, - ) - .await - { - Ok(()) => { - *refresh_status_guard = Ok(()); - Ok(()) + let provider_metadata = match self.provider_metadata().await { + Ok(metadata) => metadata, + Err(err) => { + let err = Arc::new(err); + fail!(refresh_status_guard, RefreshTokenError::Oidc(err)); } - Err(error) => { - fail!(refresh_status_guard, RefreshTokenError::Oidc(error.into())); + }; + + let Some(auth_data) = self.data() else { + fail!( + refresh_status_guard, + RefreshTokenError::Oidc(Arc::new(OidcError::NotAuthenticated)) + ); + }; + + let credentials = auth_data.credentials.clone(); + let client_metadata = auth_data.metadata.clone(); + + // Do not interrupt refresh access token requests and processing, by detaching + // the request sending and response processing. + // Make sure to keep the `refresh_status_guard` during the entire processing. + + let this = self.clone(); + + spawn(async move { + match this + .refresh_access_token_inner( + refresh_token, + provider_metadata, + credentials, + client_metadata, + session_tokens.latest_id_token, + cross_process_guard, + ) + .await + { + Ok(()) => { + *refresh_status_guard = Ok(()); + Ok(()) + } + Err(err) => { + let err = RefreshTokenError::Oidc(Arc::new(err)); + fail!(refresh_status_guard, err); + } } - } + }) + .await + .expect("joining") } /// Log out from the currently authenticated session.