refactor(oauth): Reuse the AuthorizationServerMetadata when possible

Avoids repeated calls to the same endpoint in the same flow.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-03-21 15:38:29 +01:00
committed by Ivan Enderlin
parent b3e82a05db
commit 400c92fc89
5 changed files with 74 additions and 43 deletions

View File

@@ -103,7 +103,9 @@ impl OAuthAuthCodeUrlBuilder {
login_hint,
} = self;
oauth.use_registration_method(&registration_method).await?;
let server_metadata = oauth.server_metadata().await?;
oauth.use_registration_method(&server_metadata, &registration_method).await?;
let data = oauth.data().expect("OAuth 2.0 data should be set after registration");
info!(
@@ -112,8 +114,7 @@ impl OAuthAuthCodeUrlBuilder {
"Authorizing scope via the OAuth 2.0 Authorization Code flow"
);
let server_metadata = oauth.server_metadata().await?;
let auth_url = AuthUrl::from_url(server_metadata.authorization_endpoint);
let auth_url = AuthUrl::from_url(server_metadata.authorization_endpoint.clone());
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let redirect_uri = RedirectUrl::from_url(redirect_uri);
@@ -139,7 +140,7 @@ impl OAuthAuthCodeUrlBuilder {
data.authorization_data.lock().await.insert(
state.clone(),
AuthorizationValidationData { device_id, redirect_uri, pkce_verifier },
AuthorizationValidationData { server_metadata, device_id, redirect_uri, pkce_verifier },
);
Ok(OAuthAuthorizationData { url, state })

View File

@@ -426,26 +426,23 @@ impl OAuth {
/// while registering the client.
async fn restore_or_register_client(
&self,
issuer: Url,
server_metadata: &AuthorizationServerMetadata,
registrations: &OAuthRegistrationStore,
) -> std::result::Result<(), OAuthError> {
if let Some(client_id) =
registrations.client_id(&issuer).await.map_err(OAuthClientRegistrationError::from)?
{
self.restore_registered_client(issuer, client_id);
) -> std::result::Result<(), OAuthClientRegistrationError> {
if let Some(client_id) = registrations.client_id(&server_metadata.issuer).await? {
self.restore_registered_client(server_metadata.issuer.clone(), client_id);
tracing::info!("OAuth 2.0 configuration loaded from disk.");
return Ok(());
};
tracing::info!("Registering this client for OAuth 2.0.");
let response = self.register_client(&registrations.metadata).await?;
let response = self.register_client_inner(server_metadata, &registrations.metadata).await?;
tracing::info!("Persisting OAuth 2.0 registration data.");
registrations
.set_and_write_client_id(response.client_id, issuer)
.await
.map_err(OAuthClientRegistrationError::from)?;
.set_and_write_client_id(response.client_id, server_metadata.issuer.clone())
.await?;
Ok(())
}
@@ -458,6 +455,7 @@ impl OAuth {
/// Returns an error if there was a problem using the registration method.
async fn use_registration_method(
&self,
server_metadata: &AuthorizationServerMetadata,
method: &ClientRegistrationMethod,
) -> std::result::Result<(), OAuthError> {
if self.client_id().is_some() {
@@ -468,15 +466,13 @@ impl OAuth {
match method {
ClientRegistrationMethod::None => return Err(OAuthError::NotRegistered),
ClientRegistrationMethod::ClientId(client_id) => {
let server_metadata = self.server_metadata().await?;
self.restore_registered_client(server_metadata.issuer, client_id.clone());
self.restore_registered_client(server_metadata.issuer.clone(), client_id.clone());
}
ClientRegistrationMethod::Metadata(client_metadata) => {
self.register_client(client_metadata).await?;
self.register_client_inner(server_metadata, client_metadata).await?;
}
ClientRegistrationMethod::Store(registrations) => {
let server_metadata = self.server_metadata().await?;
self.restore_or_register_client(server_metadata.issuer, registrations).await?
self.restore_or_register_client(server_metadata, registrations).await?
}
}
@@ -740,7 +736,14 @@ impl OAuth {
client_metadata: &Raw<ClientMetadata>,
) -> Result<ClientRegistrationResponse, OAuthError> {
let server_metadata = self.server_metadata().await?;
Ok(self.register_client_inner(&server_metadata, client_metadata).await?)
}
async fn register_client_inner(
&self,
server_metadata: &AuthorizationServerMetadata,
client_metadata: &Raw<ClientMetadata>,
) -> Result<ClientRegistrationResponse, OAuthClientRegistrationError> {
let registration_endpoint = server_metadata
.registration_endpoint
.as_ref()
@@ -752,7 +755,7 @@ impl OAuth {
// The format of the credentials changes according to the client metadata that
// was sent. Public clients only get a client ID.
self.restore_registered_client(
server_metadata.issuer,
server_metadata.issuer.clone(),
registration_response.client_id.clone(),
);
@@ -1116,8 +1119,7 @@ impl OAuth {
.remove(&auth_code.state)
.ok_or(OAuthAuthorizationCodeError::InvalidState)?;
let server_metadata = self.server_metadata().await?;
let token_uri = TokenUrl::from_url(server_metadata.token_endpoint);
let token_uri = TokenUrl::from_url(validation_data.server_metadata.token_endpoint.clone());
let response = OAuthClient::new(client_id)
.set_token_uri(token_uri)
@@ -1159,6 +1161,7 @@ impl OAuth {
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn request_device_authorization(
&self,
server_metadata: &AuthorizationServerMetadata,
device_id: Option<OwnedDeviceId>,
) -> Result<oauth2::StandardDeviceAuthorizationResponse, qrcode::DeviceAuthorizationOAuthError>
{
@@ -1166,7 +1169,6 @@ impl OAuth {
let client_id = self.client_id().ok_or(OAuthError::NotRegistered)?.clone();
let server_metadata = self.server_metadata().await.map_err(OAuthError::from)?;
let device_authorization_url = server_metadata
.device_authorization_endpoint
.clone()
@@ -1187,14 +1189,14 @@ impl OAuth {
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn exchange_device_code(
&self,
server_metadata: &AuthorizationServerMetadata,
device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse,
) -> Result<(), qrcode::DeviceAuthorizationOAuthError> {
use oauth2::TokenResponse;
let client_id = self.client_id().ok_or(OAuthError::NotRegistered)?.clone();
let server_metadata = self.server_metadata().await.map_err(OAuthError::from)?;
let token_uri = TokenUrl::from_url(server_metadata.token_endpoint);
let token_uri = TokenUrl::from_url(server_metadata.token_endpoint.clone());
let response = OAuthClient::new(client_id)
.set_token_uri(token_uri)
@@ -1454,6 +1456,9 @@ pub struct UserSession {
/// Authorization Code flow.
#[derive(Debug)]
struct AuthorizationValidationData {
/// The metadata of the server,
server_metadata: AuthorizationServerMetadata,
/// The device ID used in the scope.
device_id: OwnedDeviceId,

View File

@@ -22,7 +22,10 @@ use matrix_sdk_base::{
SessionMeta,
};
use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
use ruma::OwnedDeviceId;
use ruma::{
api::client::discovery::get_authorization_server_metadata::msc2965::AuthorizationServerMetadata,
OwnedDeviceId,
};
use tracing::trace;
use vodozemac::{ecies::CheckCode, Curve25519PublicKey};
@@ -33,7 +36,10 @@ use super::{
};
#[cfg(doc)]
use crate::authentication::oauth::OAuth;
use crate::{authentication::oauth::ClientRegistrationMethod, Client};
use crate::{
authentication::oauth::{ClientRegistrationMethod, OAuthError},
Client,
};
async fn send_unexpected_message_error(
channel: &mut EstablishedSecureChannel,
@@ -112,7 +118,7 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
// Register the client with the OAuth 2.0 authorization server.
trace!("Registering the client with the OAuth 2.0 authorization server.");
self.register_client().await?;
let server_metadata = self.register_client().await?;
// We want to use the Curve25519 public key for the device ID, so let's generate
// a new vodozemac `Account` now.
@@ -123,7 +129,8 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
// Let's tell the OAuth 2.0 authorization server that we want to log in using
// the device authorization grant described in [RFC8628](https://datatracker.ietf.org/doc/html/rfc8628).
trace!("Requesting device authorization.");
let auth_grant_response = self.request_device_authorization(device_id).await?;
let auth_grant_response =
self.request_device_authorization(&server_metadata, device_id).await?;
// Now we need to inform the other device of the login protocols we picked and
// the URL they should use to log us in.
@@ -160,7 +167,7 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
// Let's now wait for the access token to be provided to use by the OAuth 2.0
// authorization server.
trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
if let Err(e) = self.wait_for_tokens(&auth_grant_response).await {
if let Err(e) = self.wait_for_tokens(&server_metadata, &auth_grant_response).await {
// If we received an error, and it's one of the ones we should report to the
// other side, do so now.
if let Some(e) = e.as_request_token_error() {
@@ -282,29 +289,37 @@ impl<'a> LoginWithQrCode<'a> {
}
/// Register the client with the OAuth 2.0 authorization server.
async fn register_client(&self) -> Result<(), DeviceAuthorizationOAuthError> {
///
/// Returns the authorization server metadata.
async fn register_client(
&self,
) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
let oauth = self.client.oauth();
oauth.use_registration_method(&self.registration_method).await?;
let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
oauth.use_registration_method(&server_metadata, &self.registration_method).await?;
Ok(())
Ok(server_metadata)
}
async fn request_device_authorization(
&self,
server_metadata: &AuthorizationServerMetadata,
device_id: Curve25519PublicKey,
) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
let oauth = self.client.oauth();
let response =
oauth.request_device_authorization(Some(device_id.to_base64().into())).await?;
let response = oauth
.request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
.await?;
Ok(response)
}
async fn wait_for_tokens(
&self,
server_metadata: &AuthorizationServerMetadata,
auth_response: &StandardDeviceAuthorizationResponse,
) -> Result<(), DeviceAuthorizationOAuthError> {
let oauth = self.client.oauth();
oauth.exchange_device_code(auth_response).await?;
oauth.exchange_device_code(server_metadata, auth_response).await?;
Ok(())
}
}
@@ -414,7 +429,7 @@ mod test {
let (sender, receiver) = tokio::sync::oneshot::channel();
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
oauth_server
.mock_device_authorization()
@@ -496,7 +511,7 @@ mod test {
let (sender, receiver) = tokio::sync::oneshot::channel();
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
oauth_server
.mock_device_authorization()
@@ -631,7 +646,7 @@ mod test {
oauth_server
.mock_server_metadata()
.ok_without_device_authorization()
.expect(1..)
.expect(1)
.named("server_metadata")
.mount()
.await;

View File

@@ -46,7 +46,7 @@ async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAu
server.mock_who_am_i().ok().named("whoami").mount().await;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
oauth_server.mock_token().ok().mount().await;
@@ -275,7 +275,7 @@ async fn test_login_url() -> anyhow::Result<()> {
let issuer = Url::parse(&server.server().uri())?;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).mount().await;
oauth_server.mock_server_metadata().ok().expect(3).mount().await;
let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await;
let oauth = client.oauth();
@@ -361,9 +361,8 @@ fn test_authorization_response() -> anyhow::Result<()> {
#[async_test]
async fn test_finish_login() -> anyhow::Result<()> {
let server = MatrixMockServer::new().await;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
let server_metadata = oauth_server.server_metadata();
let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await;
let oauth = client.oauth();
@@ -383,6 +382,7 @@ async fn test_finish_login() -> anyhow::Result<()> {
let redirect_uri = REDIRECT_URI_STRING;
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let auth_validation_data = AuthorizationValidationData {
server_metadata: server_metadata.clone(),
device_id: owned_device_id!("D3V1C31D"),
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
pkce_verifier,
@@ -435,6 +435,7 @@ async fn test_finish_login() -> anyhow::Result<()> {
let redirect_uri = REDIRECT_URI_STRING;
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let auth_validation_data = AuthorizationValidationData {
server_metadata: server_metadata.clone(),
device_id: owned_device_id!("D3V1C31D"),
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
pkce_verifier,
@@ -478,6 +479,7 @@ async fn test_finish_login() -> anyhow::Result<()> {
let redirect_uri = REDIRECT_URI_STRING;
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let auth_validation_data = AuthorizationValidationData {
server_metadata,
device_id: wrong_device_id.to_owned(),
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
pkce_verifier,

View File

@@ -66,6 +66,14 @@ impl<'a> OAuthMockServer<'a> {
fn mock_endpoint<T>(&self, mock: MockBuilder, endpoint: T) -> MockEndpoint<'a, T> {
self.server.mock_endpoint(mock, endpoint)
}
/// Get the mock OAuth 2.0 server metadata.
pub fn server_metadata(&self) -> AuthorizationServerMetadata {
MockServerMetadataBuilder::new(&self.server.server().uri())
.build()
.deserialize()
.expect("mock OAuth 2.0 server metadata should deserialize successfully")
}
}
// Specific mount endpoints.