mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-03 21:45:51 -04:00
refactor(oauth): Reuse the AuthorizationServerMetadata when possible
Avoids repeated calls to the same endpoint in the same flow. Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
committed by
Ivan Enderlin
parent
b3e82a05db
commit
400c92fc89
@@ -103,7 +103,9 @@ impl OAuthAuthCodeUrlBuilder {
|
||||
login_hint,
|
||||
} = self;
|
||||
|
||||
oauth.use_registration_method(®istration_method).await?;
|
||||
let server_metadata = oauth.server_metadata().await?;
|
||||
|
||||
oauth.use_registration_method(&server_metadata, ®istration_method).await?;
|
||||
|
||||
let data = oauth.data().expect("OAuth 2.0 data should be set after registration");
|
||||
info!(
|
||||
@@ -112,8 +114,7 @@ impl OAuthAuthCodeUrlBuilder {
|
||||
"Authorizing scope via the OAuth 2.0 Authorization Code flow"
|
||||
);
|
||||
|
||||
let server_metadata = oauth.server_metadata().await?;
|
||||
let auth_url = AuthUrl::from_url(server_metadata.authorization_endpoint);
|
||||
let auth_url = AuthUrl::from_url(server_metadata.authorization_endpoint.clone());
|
||||
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let redirect_uri = RedirectUrl::from_url(redirect_uri);
|
||||
@@ -139,7 +140,7 @@ impl OAuthAuthCodeUrlBuilder {
|
||||
|
||||
data.authorization_data.lock().await.insert(
|
||||
state.clone(),
|
||||
AuthorizationValidationData { device_id, redirect_uri, pkce_verifier },
|
||||
AuthorizationValidationData { server_metadata, device_id, redirect_uri, pkce_verifier },
|
||||
);
|
||||
|
||||
Ok(OAuthAuthorizationData { url, state })
|
||||
|
||||
@@ -426,26 +426,23 @@ impl OAuth {
|
||||
/// while registering the client.
|
||||
async fn restore_or_register_client(
|
||||
&self,
|
||||
issuer: Url,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
registrations: &OAuthRegistrationStore,
|
||||
) -> std::result::Result<(), OAuthError> {
|
||||
if let Some(client_id) =
|
||||
registrations.client_id(&issuer).await.map_err(OAuthClientRegistrationError::from)?
|
||||
{
|
||||
self.restore_registered_client(issuer, client_id);
|
||||
) -> std::result::Result<(), OAuthClientRegistrationError> {
|
||||
if let Some(client_id) = registrations.client_id(&server_metadata.issuer).await? {
|
||||
self.restore_registered_client(server_metadata.issuer.clone(), client_id);
|
||||
|
||||
tracing::info!("OAuth 2.0 configuration loaded from disk.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
tracing::info!("Registering this client for OAuth 2.0.");
|
||||
let response = self.register_client(®istrations.metadata).await?;
|
||||
let response = self.register_client_inner(server_metadata, ®istrations.metadata).await?;
|
||||
|
||||
tracing::info!("Persisting OAuth 2.0 registration data.");
|
||||
registrations
|
||||
.set_and_write_client_id(response.client_id, issuer)
|
||||
.await
|
||||
.map_err(OAuthClientRegistrationError::from)?;
|
||||
.set_and_write_client_id(response.client_id, server_metadata.issuer.clone())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -458,6 +455,7 @@ impl OAuth {
|
||||
/// Returns an error if there was a problem using the registration method.
|
||||
async fn use_registration_method(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
method: &ClientRegistrationMethod,
|
||||
) -> std::result::Result<(), OAuthError> {
|
||||
if self.client_id().is_some() {
|
||||
@@ -468,15 +466,13 @@ impl OAuth {
|
||||
match method {
|
||||
ClientRegistrationMethod::None => return Err(OAuthError::NotRegistered),
|
||||
ClientRegistrationMethod::ClientId(client_id) => {
|
||||
let server_metadata = self.server_metadata().await?;
|
||||
self.restore_registered_client(server_metadata.issuer, client_id.clone());
|
||||
self.restore_registered_client(server_metadata.issuer.clone(), client_id.clone());
|
||||
}
|
||||
ClientRegistrationMethod::Metadata(client_metadata) => {
|
||||
self.register_client(client_metadata).await?;
|
||||
self.register_client_inner(server_metadata, client_metadata).await?;
|
||||
}
|
||||
ClientRegistrationMethod::Store(registrations) => {
|
||||
let server_metadata = self.server_metadata().await?;
|
||||
self.restore_or_register_client(server_metadata.issuer, registrations).await?
|
||||
self.restore_or_register_client(server_metadata, registrations).await?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -740,7 +736,14 @@ impl OAuth {
|
||||
client_metadata: &Raw<ClientMetadata>,
|
||||
) -> Result<ClientRegistrationResponse, OAuthError> {
|
||||
let server_metadata = self.server_metadata().await?;
|
||||
Ok(self.register_client_inner(&server_metadata, client_metadata).await?)
|
||||
}
|
||||
|
||||
async fn register_client_inner(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
client_metadata: &Raw<ClientMetadata>,
|
||||
) -> Result<ClientRegistrationResponse, OAuthClientRegistrationError> {
|
||||
let registration_endpoint = server_metadata
|
||||
.registration_endpoint
|
||||
.as_ref()
|
||||
@@ -752,7 +755,7 @@ impl OAuth {
|
||||
// The format of the credentials changes according to the client metadata that
|
||||
// was sent. Public clients only get a client ID.
|
||||
self.restore_registered_client(
|
||||
server_metadata.issuer,
|
||||
server_metadata.issuer.clone(),
|
||||
registration_response.client_id.clone(),
|
||||
);
|
||||
|
||||
@@ -1116,8 +1119,7 @@ impl OAuth {
|
||||
.remove(&auth_code.state)
|
||||
.ok_or(OAuthAuthorizationCodeError::InvalidState)?;
|
||||
|
||||
let server_metadata = self.server_metadata().await?;
|
||||
let token_uri = TokenUrl::from_url(server_metadata.token_endpoint);
|
||||
let token_uri = TokenUrl::from_url(validation_data.server_metadata.token_endpoint.clone());
|
||||
|
||||
let response = OAuthClient::new(client_id)
|
||||
.set_token_uri(token_uri)
|
||||
@@ -1159,6 +1161,7 @@ impl OAuth {
|
||||
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
|
||||
async fn request_device_authorization(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
device_id: Option<OwnedDeviceId>,
|
||||
) -> Result<oauth2::StandardDeviceAuthorizationResponse, qrcode::DeviceAuthorizationOAuthError>
|
||||
{
|
||||
@@ -1166,7 +1169,6 @@ impl OAuth {
|
||||
|
||||
let client_id = self.client_id().ok_or(OAuthError::NotRegistered)?.clone();
|
||||
|
||||
let server_metadata = self.server_metadata().await.map_err(OAuthError::from)?;
|
||||
let device_authorization_url = server_metadata
|
||||
.device_authorization_endpoint
|
||||
.clone()
|
||||
@@ -1187,14 +1189,14 @@ impl OAuth {
|
||||
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
|
||||
async fn exchange_device_code(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse,
|
||||
) -> Result<(), qrcode::DeviceAuthorizationOAuthError> {
|
||||
use oauth2::TokenResponse;
|
||||
|
||||
let client_id = self.client_id().ok_or(OAuthError::NotRegistered)?.clone();
|
||||
|
||||
let server_metadata = self.server_metadata().await.map_err(OAuthError::from)?;
|
||||
let token_uri = TokenUrl::from_url(server_metadata.token_endpoint);
|
||||
let token_uri = TokenUrl::from_url(server_metadata.token_endpoint.clone());
|
||||
|
||||
let response = OAuthClient::new(client_id)
|
||||
.set_token_uri(token_uri)
|
||||
@@ -1454,6 +1456,9 @@ pub struct UserSession {
|
||||
/// Authorization Code flow.
|
||||
#[derive(Debug)]
|
||||
struct AuthorizationValidationData {
|
||||
/// The metadata of the server,
|
||||
server_metadata: AuthorizationServerMetadata,
|
||||
|
||||
/// The device ID used in the scope.
|
||||
device_id: OwnedDeviceId,
|
||||
|
||||
|
||||
@@ -22,7 +22,10 @@ use matrix_sdk_base::{
|
||||
SessionMeta,
|
||||
};
|
||||
use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
|
||||
use ruma::OwnedDeviceId;
|
||||
use ruma::{
|
||||
api::client::discovery::get_authorization_server_metadata::msc2965::AuthorizationServerMetadata,
|
||||
OwnedDeviceId,
|
||||
};
|
||||
use tracing::trace;
|
||||
use vodozemac::{ecies::CheckCode, Curve25519PublicKey};
|
||||
|
||||
@@ -33,7 +36,10 @@ use super::{
|
||||
};
|
||||
#[cfg(doc)]
|
||||
use crate::authentication::oauth::OAuth;
|
||||
use crate::{authentication::oauth::ClientRegistrationMethod, Client};
|
||||
use crate::{
|
||||
authentication::oauth::{ClientRegistrationMethod, OAuthError},
|
||||
Client,
|
||||
};
|
||||
|
||||
async fn send_unexpected_message_error(
|
||||
channel: &mut EstablishedSecureChannel,
|
||||
@@ -112,7 +118,7 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
|
||||
|
||||
// Register the client with the OAuth 2.0 authorization server.
|
||||
trace!("Registering the client with the OAuth 2.0 authorization server.");
|
||||
self.register_client().await?;
|
||||
let server_metadata = self.register_client().await?;
|
||||
|
||||
// We want to use the Curve25519 public key for the device ID, so let's generate
|
||||
// a new vodozemac `Account` now.
|
||||
@@ -123,7 +129,8 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
|
||||
// Let's tell the OAuth 2.0 authorization server that we want to log in using
|
||||
// the device authorization grant described in [RFC8628](https://datatracker.ietf.org/doc/html/rfc8628).
|
||||
trace!("Requesting device authorization.");
|
||||
let auth_grant_response = self.request_device_authorization(device_id).await?;
|
||||
let auth_grant_response =
|
||||
self.request_device_authorization(&server_metadata, device_id).await?;
|
||||
|
||||
// Now we need to inform the other device of the login protocols we picked and
|
||||
// the URL they should use to log us in.
|
||||
@@ -160,7 +167,7 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
|
||||
// Let's now wait for the access token to be provided to use by the OAuth 2.0
|
||||
// authorization server.
|
||||
trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
|
||||
if let Err(e) = self.wait_for_tokens(&auth_grant_response).await {
|
||||
if let Err(e) = self.wait_for_tokens(&server_metadata, &auth_grant_response).await {
|
||||
// If we received an error, and it's one of the ones we should report to the
|
||||
// other side, do so now.
|
||||
if let Some(e) = e.as_request_token_error() {
|
||||
@@ -282,29 +289,37 @@ impl<'a> LoginWithQrCode<'a> {
|
||||
}
|
||||
|
||||
/// Register the client with the OAuth 2.0 authorization server.
|
||||
async fn register_client(&self) -> Result<(), DeviceAuthorizationOAuthError> {
|
||||
///
|
||||
/// Returns the authorization server metadata.
|
||||
async fn register_client(
|
||||
&self,
|
||||
) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
|
||||
let oauth = self.client.oauth();
|
||||
oauth.use_registration_method(&self.registration_method).await?;
|
||||
let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
|
||||
oauth.use_registration_method(&server_metadata, &self.registration_method).await?;
|
||||
|
||||
Ok(())
|
||||
Ok(server_metadata)
|
||||
}
|
||||
|
||||
async fn request_device_authorization(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
device_id: Curve25519PublicKey,
|
||||
) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
|
||||
let oauth = self.client.oauth();
|
||||
let response =
|
||||
oauth.request_device_authorization(Some(device_id.to_base64().into())).await?;
|
||||
let response = oauth
|
||||
.request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
|
||||
.await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn wait_for_tokens(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
auth_response: &StandardDeviceAuthorizationResponse,
|
||||
) -> Result<(), DeviceAuthorizationOAuthError> {
|
||||
let oauth = self.client.oauth();
|
||||
oauth.exchange_device_code(auth_response).await?;
|
||||
oauth.exchange_device_code(server_metadata, auth_response).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -414,7 +429,7 @@ mod test {
|
||||
let (sender, receiver) = tokio::sync::oneshot::channel();
|
||||
|
||||
let oauth_server = server.oauth();
|
||||
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
|
||||
oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
|
||||
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
|
||||
oauth_server
|
||||
.mock_device_authorization()
|
||||
@@ -496,7 +511,7 @@ mod test {
|
||||
let (sender, receiver) = tokio::sync::oneshot::channel();
|
||||
|
||||
let oauth_server = server.oauth();
|
||||
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
|
||||
oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
|
||||
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
|
||||
oauth_server
|
||||
.mock_device_authorization()
|
||||
@@ -631,7 +646,7 @@ mod test {
|
||||
oauth_server
|
||||
.mock_server_metadata()
|
||||
.ok_without_device_authorization()
|
||||
.expect(1..)
|
||||
.expect(1)
|
||||
.named("server_metadata")
|
||||
.mount()
|
||||
.await;
|
||||
|
||||
@@ -46,7 +46,7 @@ async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAu
|
||||
server.mock_who_am_i().ok().named("whoami").mount().await;
|
||||
|
||||
let oauth_server = server.oauth();
|
||||
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
|
||||
oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
|
||||
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
|
||||
oauth_server.mock_token().ok().mount().await;
|
||||
|
||||
@@ -275,7 +275,7 @@ async fn test_login_url() -> anyhow::Result<()> {
|
||||
let issuer = Url::parse(&server.server().uri())?;
|
||||
|
||||
let oauth_server = server.oauth();
|
||||
oauth_server.mock_server_metadata().ok().expect(1..).mount().await;
|
||||
oauth_server.mock_server_metadata().ok().expect(3).mount().await;
|
||||
|
||||
let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await;
|
||||
let oauth = client.oauth();
|
||||
@@ -361,9 +361,8 @@ fn test_authorization_response() -> anyhow::Result<()> {
|
||||
#[async_test]
|
||||
async fn test_finish_login() -> anyhow::Result<()> {
|
||||
let server = MatrixMockServer::new().await;
|
||||
|
||||
let oauth_server = server.oauth();
|
||||
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
|
||||
let server_metadata = oauth_server.server_metadata();
|
||||
|
||||
let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await;
|
||||
let oauth = client.oauth();
|
||||
@@ -383,6 +382,7 @@ async fn test_finish_login() -> anyhow::Result<()> {
|
||||
let redirect_uri = REDIRECT_URI_STRING;
|
||||
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let auth_validation_data = AuthorizationValidationData {
|
||||
server_metadata: server_metadata.clone(),
|
||||
device_id: owned_device_id!("D3V1C31D"),
|
||||
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
|
||||
pkce_verifier,
|
||||
@@ -435,6 +435,7 @@ async fn test_finish_login() -> anyhow::Result<()> {
|
||||
let redirect_uri = REDIRECT_URI_STRING;
|
||||
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let auth_validation_data = AuthorizationValidationData {
|
||||
server_metadata: server_metadata.clone(),
|
||||
device_id: owned_device_id!("D3V1C31D"),
|
||||
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
|
||||
pkce_verifier,
|
||||
@@ -478,6 +479,7 @@ async fn test_finish_login() -> anyhow::Result<()> {
|
||||
let redirect_uri = REDIRECT_URI_STRING;
|
||||
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let auth_validation_data = AuthorizationValidationData {
|
||||
server_metadata,
|
||||
device_id: wrong_device_id.to_owned(),
|
||||
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
|
||||
pkce_verifier,
|
||||
|
||||
@@ -66,6 +66,14 @@ impl<'a> OAuthMockServer<'a> {
|
||||
fn mock_endpoint<T>(&self, mock: MockBuilder, endpoint: T) -> MockEndpoint<'a, T> {
|
||||
self.server.mock_endpoint(mock, endpoint)
|
||||
}
|
||||
|
||||
/// Get the mock OAuth 2.0 server metadata.
|
||||
pub fn server_metadata(&self) -> AuthorizationServerMetadata {
|
||||
MockServerMetadataBuilder::new(&self.server.server().uri())
|
||||
.build()
|
||||
.deserialize()
|
||||
.expect("mock OAuth 2.0 server metadata should deserialize successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// Specific mount endpoints.
|
||||
|
||||
Reference in New Issue
Block a user