diff --git a/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs b/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs index a356648de..8c88b47de 100644 --- a/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs +++ b/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs @@ -103,7 +103,9 @@ impl OAuthAuthCodeUrlBuilder { login_hint, } = self; - oauth.use_registration_method(®istration_method).await?; + let server_metadata = oauth.server_metadata().await?; + + oauth.use_registration_method(&server_metadata, ®istration_method).await?; let data = oauth.data().expect("OAuth 2.0 data should be set after registration"); info!( @@ -112,8 +114,7 @@ impl OAuthAuthCodeUrlBuilder { "Authorizing scope via the OAuth 2.0 Authorization Code flow" ); - let server_metadata = oauth.server_metadata().await?; - let auth_url = AuthUrl::from_url(server_metadata.authorization_endpoint); + let auth_url = AuthUrl::from_url(server_metadata.authorization_endpoint.clone()); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let redirect_uri = RedirectUrl::from_url(redirect_uri); @@ -139,7 +140,7 @@ impl OAuthAuthCodeUrlBuilder { data.authorization_data.lock().await.insert( state.clone(), - AuthorizationValidationData { device_id, redirect_uri, pkce_verifier }, + AuthorizationValidationData { server_metadata, device_id, redirect_uri, pkce_verifier }, ); Ok(OAuthAuthorizationData { url, state }) diff --git a/crates/matrix-sdk/src/authentication/oauth/mod.rs b/crates/matrix-sdk/src/authentication/oauth/mod.rs index 16392b629..a30062af7 100644 --- a/crates/matrix-sdk/src/authentication/oauth/mod.rs +++ b/crates/matrix-sdk/src/authentication/oauth/mod.rs @@ -426,26 +426,23 @@ impl OAuth { /// while registering the client. async fn restore_or_register_client( &self, - issuer: Url, + server_metadata: &AuthorizationServerMetadata, registrations: &OAuthRegistrationStore, - ) -> std::result::Result<(), OAuthError> { - if let Some(client_id) = - registrations.client_id(&issuer).await.map_err(OAuthClientRegistrationError::from)? - { - self.restore_registered_client(issuer, client_id); + ) -> std::result::Result<(), OAuthClientRegistrationError> { + if let Some(client_id) = registrations.client_id(&server_metadata.issuer).await? { + self.restore_registered_client(server_metadata.issuer.clone(), client_id); tracing::info!("OAuth 2.0 configuration loaded from disk."); return Ok(()); }; tracing::info!("Registering this client for OAuth 2.0."); - let response = self.register_client(®istrations.metadata).await?; + let response = self.register_client_inner(server_metadata, ®istrations.metadata).await?; tracing::info!("Persisting OAuth 2.0 registration data."); registrations - .set_and_write_client_id(response.client_id, issuer) - .await - .map_err(OAuthClientRegistrationError::from)?; + .set_and_write_client_id(response.client_id, server_metadata.issuer.clone()) + .await?; Ok(()) } @@ -458,6 +455,7 @@ impl OAuth { /// Returns an error if there was a problem using the registration method. async fn use_registration_method( &self, + server_metadata: &AuthorizationServerMetadata, method: &ClientRegistrationMethod, ) -> std::result::Result<(), OAuthError> { if self.client_id().is_some() { @@ -468,15 +466,13 @@ impl OAuth { match method { ClientRegistrationMethod::None => return Err(OAuthError::NotRegistered), ClientRegistrationMethod::ClientId(client_id) => { - let server_metadata = self.server_metadata().await?; - self.restore_registered_client(server_metadata.issuer, client_id.clone()); + self.restore_registered_client(server_metadata.issuer.clone(), client_id.clone()); } ClientRegistrationMethod::Metadata(client_metadata) => { - self.register_client(client_metadata).await?; + self.register_client_inner(server_metadata, client_metadata).await?; } ClientRegistrationMethod::Store(registrations) => { - let server_metadata = self.server_metadata().await?; - self.restore_or_register_client(server_metadata.issuer, registrations).await? + self.restore_or_register_client(server_metadata, registrations).await? } } @@ -740,7 +736,14 @@ impl OAuth { client_metadata: &Raw, ) -> Result { let server_metadata = self.server_metadata().await?; + Ok(self.register_client_inner(&server_metadata, client_metadata).await?) + } + async fn register_client_inner( + &self, + server_metadata: &AuthorizationServerMetadata, + client_metadata: &Raw, + ) -> Result { let registration_endpoint = server_metadata .registration_endpoint .as_ref() @@ -752,7 +755,7 @@ impl OAuth { // The format of the credentials changes according to the client metadata that // was sent. Public clients only get a client ID. self.restore_registered_client( - server_metadata.issuer, + server_metadata.issuer.clone(), registration_response.client_id.clone(), ); @@ -1116,8 +1119,7 @@ impl OAuth { .remove(&auth_code.state) .ok_or(OAuthAuthorizationCodeError::InvalidState)?; - let server_metadata = self.server_metadata().await?; - let token_uri = TokenUrl::from_url(server_metadata.token_endpoint); + let token_uri = TokenUrl::from_url(validation_data.server_metadata.token_endpoint.clone()); let response = OAuthClient::new(client_id) .set_token_uri(token_uri) @@ -1159,6 +1161,7 @@ impl OAuth { #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] async fn request_device_authorization( &self, + server_metadata: &AuthorizationServerMetadata, device_id: Option, ) -> Result { @@ -1166,7 +1169,6 @@ impl OAuth { let client_id = self.client_id().ok_or(OAuthError::NotRegistered)?.clone(); - let server_metadata = self.server_metadata().await.map_err(OAuthError::from)?; let device_authorization_url = server_metadata .device_authorization_endpoint .clone() @@ -1187,14 +1189,14 @@ impl OAuth { #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] async fn exchange_device_code( &self, + server_metadata: &AuthorizationServerMetadata, device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse, ) -> Result<(), qrcode::DeviceAuthorizationOAuthError> { use oauth2::TokenResponse; let client_id = self.client_id().ok_or(OAuthError::NotRegistered)?.clone(); - let server_metadata = self.server_metadata().await.map_err(OAuthError::from)?; - let token_uri = TokenUrl::from_url(server_metadata.token_endpoint); + let token_uri = TokenUrl::from_url(server_metadata.token_endpoint.clone()); let response = OAuthClient::new(client_id) .set_token_uri(token_uri) @@ -1454,6 +1456,9 @@ pub struct UserSession { /// Authorization Code flow. #[derive(Debug)] struct AuthorizationValidationData { + /// The metadata of the server, + server_metadata: AuthorizationServerMetadata, + /// The device ID used in the scope. device_id: OwnedDeviceId, diff --git a/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs b/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs index 6bee65e49..985c587cc 100644 --- a/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs +++ b/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs @@ -22,7 +22,10 @@ use matrix_sdk_base::{ SessionMeta, }; use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse}; -use ruma::OwnedDeviceId; +use ruma::{ + api::client::discovery::get_authorization_server_metadata::msc2965::AuthorizationServerMetadata, + OwnedDeviceId, +}; use tracing::trace; use vodozemac::{ecies::CheckCode, Curve25519PublicKey}; @@ -33,7 +36,10 @@ use super::{ }; #[cfg(doc)] use crate::authentication::oauth::OAuth; -use crate::{authentication::oauth::ClientRegistrationMethod, Client}; +use crate::{ + authentication::oauth::{ClientRegistrationMethod, OAuthError}, + Client, +}; async fn send_unexpected_message_error( channel: &mut EstablishedSecureChannel, @@ -112,7 +118,7 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> { // Register the client with the OAuth 2.0 authorization server. trace!("Registering the client with the OAuth 2.0 authorization server."); - self.register_client().await?; + let server_metadata = self.register_client().await?; // We want to use the Curve25519 public key for the device ID, so let's generate // a new vodozemac `Account` now. @@ -123,7 +129,8 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> { // Let's tell the OAuth 2.0 authorization server that we want to log in using // the device authorization grant described in [RFC8628](https://datatracker.ietf.org/doc/html/rfc8628). trace!("Requesting device authorization."); - let auth_grant_response = self.request_device_authorization(device_id).await?; + let auth_grant_response = + self.request_device_authorization(&server_metadata, device_id).await?; // Now we need to inform the other device of the login protocols we picked and // the URL they should use to log us in. @@ -160,7 +167,7 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> { // Let's now wait for the access token to be provided to use by the OAuth 2.0 // authorization server. trace!("Waiting for the OAuth 2.0 authorization server to give us the access token."); - if let Err(e) = self.wait_for_tokens(&auth_grant_response).await { + if let Err(e) = self.wait_for_tokens(&server_metadata, &auth_grant_response).await { // If we received an error, and it's one of the ones we should report to the // other side, do so now. if let Some(e) = e.as_request_token_error() { @@ -282,29 +289,37 @@ impl<'a> LoginWithQrCode<'a> { } /// Register the client with the OAuth 2.0 authorization server. - async fn register_client(&self) -> Result<(), DeviceAuthorizationOAuthError> { + /// + /// Returns the authorization server metadata. + async fn register_client( + &self, + ) -> Result { let oauth = self.client.oauth(); - oauth.use_registration_method(&self.registration_method).await?; + let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?; + oauth.use_registration_method(&server_metadata, &self.registration_method).await?; - Ok(()) + Ok(server_metadata) } async fn request_device_authorization( &self, + server_metadata: &AuthorizationServerMetadata, device_id: Curve25519PublicKey, ) -> Result { let oauth = self.client.oauth(); - let response = - oauth.request_device_authorization(Some(device_id.to_base64().into())).await?; + let response = oauth + .request_device_authorization(server_metadata, Some(device_id.to_base64().into())) + .await?; Ok(response) } async fn wait_for_tokens( &self, + server_metadata: &AuthorizationServerMetadata, auth_response: &StandardDeviceAuthorizationResponse, ) -> Result<(), DeviceAuthorizationOAuthError> { let oauth = self.client.oauth(); - oauth.exchange_device_code(auth_response).await?; + oauth.exchange_device_code(server_metadata, auth_response).await?; Ok(()) } } @@ -414,7 +429,7 @@ mod test { let (sender, receiver) = tokio::sync::oneshot::channel(); let oauth_server = server.oauth(); - oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await; oauth_server.mock_registration().ok().expect(1).named("registration").mount().await; oauth_server .mock_device_authorization() @@ -496,7 +511,7 @@ mod test { let (sender, receiver) = tokio::sync::oneshot::channel(); let oauth_server = server.oauth(); - oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await; oauth_server.mock_registration().ok().expect(1).named("registration").mount().await; oauth_server .mock_device_authorization() @@ -631,7 +646,7 @@ mod test { oauth_server .mock_server_metadata() .ok_without_device_authorization() - .expect(1..) + .expect(1) .named("server_metadata") .mount() .await; diff --git a/crates/matrix-sdk/src/authentication/oauth/tests.rs b/crates/matrix-sdk/src/authentication/oauth/tests.rs index 37c243dc6..d1af89a47 100644 --- a/crates/matrix-sdk/src/authentication/oauth/tests.rs +++ b/crates/matrix-sdk/src/authentication/oauth/tests.rs @@ -46,7 +46,7 @@ async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAu server.mock_who_am_i().ok().named("whoami").mount().await; let oauth_server = server.oauth(); - oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await; oauth_server.mock_registration().ok().expect(1).named("registration").mount().await; oauth_server.mock_token().ok().mount().await; @@ -275,7 +275,7 @@ async fn test_login_url() -> anyhow::Result<()> { let issuer = Url::parse(&server.server().uri())?; let oauth_server = server.oauth(); - oauth_server.mock_server_metadata().ok().expect(1..).mount().await; + oauth_server.mock_server_metadata().ok().expect(3).mount().await; let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await; let oauth = client.oauth(); @@ -361,9 +361,8 @@ fn test_authorization_response() -> anyhow::Result<()> { #[async_test] async fn test_finish_login() -> anyhow::Result<()> { let server = MatrixMockServer::new().await; - let oauth_server = server.oauth(); - oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + let server_metadata = oauth_server.server_metadata(); let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await; let oauth = client.oauth(); @@ -383,6 +382,7 @@ async fn test_finish_login() -> anyhow::Result<()> { let redirect_uri = REDIRECT_URI_STRING; let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let auth_validation_data = AuthorizationValidationData { + server_metadata: server_metadata.clone(), device_id: owned_device_id!("D3V1C31D"), redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?, pkce_verifier, @@ -435,6 +435,7 @@ async fn test_finish_login() -> anyhow::Result<()> { let redirect_uri = REDIRECT_URI_STRING; let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let auth_validation_data = AuthorizationValidationData { + server_metadata: server_metadata.clone(), device_id: owned_device_id!("D3V1C31D"), redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?, pkce_verifier, @@ -478,6 +479,7 @@ async fn test_finish_login() -> anyhow::Result<()> { let redirect_uri = REDIRECT_URI_STRING; let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let auth_validation_data = AuthorizationValidationData { + server_metadata, device_id: wrong_device_id.to_owned(), redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?, pkce_verifier, diff --git a/crates/matrix-sdk/src/test_utils/mocks/oauth.rs b/crates/matrix-sdk/src/test_utils/mocks/oauth.rs index eb833c915..791766d2f 100644 --- a/crates/matrix-sdk/src/test_utils/mocks/oauth.rs +++ b/crates/matrix-sdk/src/test_utils/mocks/oauth.rs @@ -66,6 +66,14 @@ impl<'a> OAuthMockServer<'a> { fn mock_endpoint(&self, mock: MockBuilder, endpoint: T) -> MockEndpoint<'a, T> { self.server.mock_endpoint(mock, endpoint) } + + /// Get the mock OAuth 2.0 server metadata. + pub fn server_metadata(&self) -> AuthorizationServerMetadata { + MockServerMetadataBuilder::new(&self.server.server().uri()) + .build() + .deserialize() + .expect("mock OAuth 2.0 server metadata should deserialize successfully") + } } // Specific mount endpoints.