From abc4fbc2f7aba3bc4b95fffdd3ce736ff2ee4555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Thu, 6 Mar 2025 16:10:54 +0100 Subject: [PATCH] refactor(oidc): Import code to discover OIDC provider configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Commaille --- .../src/authentication/oidc/error.rs | 12 +- .../src/authentication/oidc/http_client.rs | 191 ++++++++++++++++++ .../matrix-sdk/src/authentication/oidc/mod.rs | 117 ++++------- .../src/authentication/oidc/oidc_discovery.rs | 64 ++++++ 4 files changed, 304 insertions(+), 80 deletions(-) create mode 100644 crates/matrix-sdk/src/authentication/oidc/http_client.rs create mode 100644 crates/matrix-sdk/src/authentication/oidc/oidc_discovery.rs diff --git a/crates/matrix-sdk/src/authentication/oidc/error.rs b/crates/matrix-sdk/src/authentication/oidc/error.rs index 0ed8c892e..ca2b37b6e 100644 --- a/crates/matrix-sdk/src/authentication/oidc/error.rs +++ b/crates/matrix-sdk/src/authentication/oidc/error.rs @@ -15,6 +15,7 @@ //! Error types used in the [`Oidc`](super::Oidc) API. pub use mas_oidc_client::error::*; +use mas_oidc_client::types::oidc::ProviderMetadataVerificationError; use matrix_sdk_base::deserialized_responses::PrivOwnedStr; use oauth2::ErrorResponseType; pub use oauth2::{ @@ -131,9 +132,18 @@ pub enum OauthDiscoveryError { #[error(transparent)] Json(#[from] serde_json::Error), + /// The server metadata is incomplete or insecure. + #[error(transparent)] + Validation(#[from] ProviderMetadataVerificationError), + + /// An error occurred when building the OpenID Connect provider + /// configuration URL. + #[error(transparent)] + Url(#[from] url::ParseError), + /// An error occurred when making a request to the OpenID Connect provider. #[error(transparent)] - Oidc(#[from] DiscoveryError), + Oidc(#[from] OauthRequestError), } impl OauthDiscoveryError { diff --git a/crates/matrix-sdk/src/authentication/oidc/http_client.rs b/crates/matrix-sdk/src/authentication/oidc/http_client.rs new file mode 100644 index 000000000..fa3a95dee --- /dev/null +++ b/crates/matrix-sdk/src/authentication/oidc/http_client.rs @@ -0,0 +1,191 @@ +// Copyright 2025 Kévin Commaille +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! HTTP client and helpers for making OAuth 2.0 requests. + +use std::{future::Future, pin::Pin}; + +use oauth2::{ + AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError, +}; + +/// An HTTP client for making OAuth 2.0 requests. +#[derive(Debug, Clone)] +pub(super) struct OauthHttpClient { + pub(super) inner: reqwest::Client, + /// Rewrite HTTPS requests to use HTTP instead. + /// + /// This is a workaround to bypass some checks that require an HTTPS URL, + /// but we can only mock HTTP URLs. + #[cfg(test)] + pub(super) insecure_rewrite_https_to_http: bool, +} + +impl<'c> AsyncHttpClient<'c> for OauthHttpClient { + type Error = HttpClientError; + + type Future = + Pin> + Send + Sync + 'c>>; + + fn call(&'c self, request: HttpRequest) -> Self::Future { + Box::pin(async move { + #[cfg(test)] + let request = if self.insecure_rewrite_https_to_http + && request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS) + { + let mut request = request; + + let mut uri_parts = request.uri().clone().into_parts(); + uri_parts.scheme = Some(http::uri::Scheme::HTTP); + *request.uri_mut() = http::uri::Uri::from_parts(uri_parts) + .expect("reconstructing URI from parts should work"); + + request + } else { + request + }; + + let response = self.inner.call(request).await?; + + Ok(response) + }) + } +} + +/// Check the status code of the given HTTP response to identify errors. +pub(super) fn check_http_response_status_code( + http_response: &HttpResponse, +) -> Result<(), RequestTokenError, T>> { + if http_response.status().as_u16() < 400 { + return Ok(()); + } + + let reason = http_response.body().as_slice(); + let error = if reason.is_empty() { + RequestTokenError::Other("server returned an empty error response".to_owned()) + } else { + match serde_json::from_slice(reason) { + Ok(error) => RequestTokenError::ServerResponse(error), + Err(error) => RequestTokenError::Other(error.to_string()), + } + }; + + Err(error) +} + +/// Check that the server returned a response with a JSON `Content-Type`. +pub(super) fn check_http_response_json_content_type( + http_response: &HttpResponse, +) -> Result<(), RequestTokenError, T>> { + let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else { + return Ok(()); + }; + + if content_type + .to_str() + // Check only the beginning of the content type, because there might be extra + // parameters, like a charset. + .is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str())) + { + Ok(()) + } else { + Err(RequestTokenError::Other(format!( + "unexpected response Content-Type: {content_type:?}, should be `{}`", + mime::APPLICATION_JSON + ))) + } +} + +#[cfg(test)] +mod tests { + use assert_matches2::assert_matches; + use oauth2::{basic::BasicErrorResponse, RequestTokenError}; + + use super::{check_http_response_json_content_type, check_http_response_status_code}; + + #[test] + fn test_check_http_response_status_code() { + // OK + let response = http::Response::builder().status(200).body(Vec::::new()).unwrap(); + assert_matches!(check_http_response_status_code::(&response), Ok(())); + + // Error without body. + let response = http::Response::builder().status(404).body(Vec::::new()).unwrap(); + assert_matches!( + check_http_response_status_code::(&response), + Err(RequestTokenError::Other(_)) + ); + + // Error with invalid body. + let response = + http::Response::builder().status(404).body(b"invalid error format".to_vec()).unwrap(); + assert_matches!( + check_http_response_status_code::(&response), + Err(RequestTokenError::Other(_)) + ); + + // Error with valid body. + let response = http::Response::builder() + .status(404) + .body(br#"{"error": "invalid_request"}"#.to_vec()) + .unwrap(); + assert_matches!( + check_http_response_status_code::(&response), + Err(RequestTokenError::ServerResponse(_)) + ); + } + + #[test] + fn test_check_http_response_json_content_type() { + // Valid content type. + let response = http::Response::builder() + .status(200) + .header(http::header::CONTENT_TYPE, "application/json") + .body(b"{}".to_vec()) + .unwrap(); + assert_matches!( + check_http_response_json_content_type::(&response), + Ok(()) + ); + + // Valid content type with charset. + let response = http::Response::builder() + .status(200) + .header(http::header::CONTENT_TYPE, "application/json; charset=utf-8") + .body(b"{}".to_vec()) + .unwrap(); + assert_matches!( + check_http_response_json_content_type::(&response), + Ok(()) + ); + + // Without content type. + let response = http::Response::builder().status(200).body(b"{}".to_vec()).unwrap(); + assert_matches!( + check_http_response_json_content_type::(&response), + Ok(()) + ); + + // Wrong content type. + let response = http::Response::builder() + .status(200) + .header(http::header::CONTENT_TYPE, "text/html") + .body(b"

HTML!

".to_vec()) + .unwrap(); + assert_matches!( + check_http_response_json_content_type::(&response), + Err(RequestTokenError::Other(_)) + ); + } +} diff --git a/crates/matrix-sdk/src/authentication/oidc/mod.rs b/crates/matrix-sdk/src/authentication/oidc/mod.rs index 07b3aa01e..74be213fa 100644 --- a/crates/matrix-sdk/src/authentication/oidc/mod.rs +++ b/crates/matrix-sdk/src/authentication/oidc/mod.rs @@ -146,7 +146,7 @@ //! [`AuthenticateError::InsufficientScope`]: ruma::api::client::error::AuthenticateError //! [`examples/oidc_cli`]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/examples/oidc_cli -use std::{borrow::Cow, collections::HashMap, fmt, future::Future, pin::Pin, sync::Arc}; +use std::{borrow::Cow, collections::HashMap, fmt, sync::Arc}; use as_variant::as_variant; use error::{ @@ -155,10 +155,7 @@ use error::{ }; use mas_oidc_client::{ http_service::HttpService, - requests::{ - discovery::{discover, insecure_discover}, - registration::register_client, - }, + requests::registration::register_client, types::{ oidc::{ AccountManagementAction, ProviderMetadata, ProviderMetadataVerificationError, @@ -173,15 +170,18 @@ use matrix_sdk_base::crypto::types::qr_login::QrCodeData; use matrix_sdk_base::{once_cell::sync::OnceCell, SessionMeta}; pub use oauth2::CsrfToken; use oauth2::{ - basic::BasicClient as OauthClient, AccessToken, AsyncHttpClient, HttpRequest, HttpResponse, - PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, Scope, StandardErrorResponse, - StandardRevocableToken, TokenResponse, TokenUrl, + basic::BasicClient as OauthClient, AccessToken, PkceCodeVerifier, RedirectUrl, RefreshToken, + RevocationUrl, Scope, StandardErrorResponse, StandardRevocableToken, TokenResponse, TokenUrl, }; use ruma::{ api::client::discovery::{ get_authentication_issuer, - get_authorization_server_metadata::{self, msc2965::Prompt}, + get_authorization_server_metadata::{ + self, + msc2965::{AuthorizationServerMetadata, Prompt}, + }, }, + serde::Raw, DeviceId, OwnedDeviceId, }; use serde::{Deserialize, Serialize}; @@ -194,6 +194,8 @@ mod account_management_url; mod auth_code_builder; mod cross_process; pub mod error; +mod http_client; +mod oidc_discovery; #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] pub mod qrcode; pub mod registrations; @@ -203,6 +205,8 @@ mod tests; use self::{ account_management_url::build_account_management_url, cross_process::{CrossProcessRefreshLockGuard, CrossProcessRefreshManager}, + http_client::OauthHttpClient, + oidc_discovery::discover, qrcode::LoginWithQrCode, registrations::{ClientId, OidcRegistrations}, }; @@ -677,8 +681,7 @@ impl Oidc { /// MSC2956: https://github.com/matrix-org/matrix-spec-proposals/pull/2965 async fn fallback_discover( &self, - insecure: bool, - ) -> Result { + ) -> Result, OauthDiscoveryError> { #[allow(deprecated)] let issuer = match self.client.send(get_authentication_issuer::msc2965::Request::new()).await { @@ -693,11 +696,7 @@ impl Oidc { Err(error) => return Err(error.into()), }; - if insecure { - insecure_discover(&self.http_service(), &issuer).await.map_err(Into::into) - } else { - discover(&self.http_service(), &issuer).await.map_err(Into::into) - } + discover(self.http_client(), &issuer).await } /// Fetch the OAuth 2.0 server metadata of the homeserver. @@ -711,33 +710,35 @@ impl Oidc { .is_some_and(|err| err.status_code == http::StatusCode::NOT_FOUND) }; - match self.client.send(get_authorization_server_metadata::msc2965::Request::new()).await { - Ok(response) => { - let metadata = response.metadata.deserialize_as::()?; - - let result = if self.ctx().insecure_discover { - metadata.insecure_verify_metadata() - } else { - // The mas-oidc-client method needs to compare the issuer for validation. It's a - // bit unnecessary because we take it from the metadata, oh well. - let issuer = - metadata.issuer.clone().ok_or(error::DiscoveryError::Validation( - ProviderMetadataVerificationError::MissingIssuer, - ))?; - metadata.validate(&issuer) - }; - - Ok(result.map_err(error::DiscoveryError::Validation)?) - } + let raw_metadata = match self + .client + .send(get_authorization_server_metadata::msc2965::Request::new()) + .await + { + Ok(response) => response.metadata, // If the endpoint returns a 404, i.e. the server doesn't support the endpoint, attempt // to use the equivalent, but deprecated, endpoint. Err(error) if is_endpoint_unsupported(&error) => { // TODO: remove this fallback behavior when the metadata endpoint has wider // support. - self.fallback_discover(self.ctx().insecure_discover).await + self.fallback_discover().await? } - Err(error) => Err(error.into()), - } + Err(error) => return Err(error.into()), + }; + + let metadata = raw_metadata.deserialize_as::()?; + + let metadata = if self.ctx().insecure_discover { + metadata.insecure_verify_metadata()? + } else { + // The mas-oidc-client method needs to compare the issuer for validation. It's a + // bit unnecessary because we take it from the metadata, oh well. + let issuer = + metadata.issuer.clone().ok_or(ProviderMetadataVerificationError::MissingIssuer)?; + metadata.validate(&issuer)? + }; + + Ok(metadata) } /// The OpenID Connect unique identifier of this client obtained after @@ -1613,45 +1614,3 @@ pub struct AuthorizationError { fn hash_str(x: &str) -> impl fmt::LowerHex { sha2::Sha256::new().chain_update(x).finalize() } - -#[derive(Debug, Clone)] -struct OauthHttpClient { - inner: reqwest::Client, - /// Rewrite HTTPS requests to use HTTP instead. - /// - /// This is a workaround to bypass some checks that require an HTTPS URL, - /// but we can only mock HTTP URLs. - #[cfg(test)] - insecure_rewrite_https_to_http: bool, -} - -impl<'c> AsyncHttpClient<'c> for OauthHttpClient { - type Error = error::HttpClientError; - - type Future = - Pin> + Send + Sync + 'c>>; - - fn call(&'c self, request: HttpRequest) -> Self::Future { - Box::pin(async move { - #[cfg(test)] - let request = if self.insecure_rewrite_https_to_http - && request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS) - { - let mut request = request; - - let mut uri_parts = request.uri().clone().into_parts(); - uri_parts.scheme = Some(http::uri::Scheme::HTTP); - *request.uri_mut() = http::uri::Uri::from_parts(uri_parts) - .expect("reconstructing URI from parts should work"); - - request - } else { - request - }; - - let response = self.inner.call(request).await?; - - Ok(response) - }) - } -} diff --git a/crates/matrix-sdk/src/authentication/oidc/oidc_discovery.rs b/crates/matrix-sdk/src/authentication/oidc/oidc_discovery.rs new file mode 100644 index 000000000..ef5aefd27 --- /dev/null +++ b/crates/matrix-sdk/src/authentication/oidc/oidc_discovery.rs @@ -0,0 +1,64 @@ +// Copyright 2025 Kévin Commaille +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Requests for [OpenID Connect Provider Discovery]. +//! +//! [OpenID Connect Provider Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html + +use oauth2::{AsyncHttpClient, HttpClientError, RequestTokenError}; +use ruma::{ + api::client::discovery::get_authorization_server_metadata::msc2965::AuthorizationServerMetadata, + serde::Raw, +}; +use url::Url; + +use super::{ + error::OauthDiscoveryError, + http_client::{check_http_response_json_content_type, check_http_response_status_code}, + OauthHttpClient, +}; + +/// Fetch the OpenID Connect provider metadata. +pub(super) async fn discover( + http_client: &OauthHttpClient, + issuer: &str, +) -> Result, OauthDiscoveryError> { + tracing::debug!("Fetching provider metadata..."); + + let mut url = Url::parse(issuer)?; + + // If the path doesn't end with a slash, the last segment is removed when + // using `join`. + if !url.path().ends_with('/') { + let mut path = url.path().to_owned(); + path.push('/'); + url.set_path(&path); + } + + let config_url = url.join(".well-known/openid-configuration")?; + + let request = http::Request::get(config_url.as_str()) + .body(Vec::new()) + .map_err(|err| RequestTokenError::Request(HttpClientError::Http(err)))?; + let response = http_client.call(request).await.map_err(RequestTokenError::Request)?; + + check_http_response_status_code(&response)?; + check_http_response_json_content_type(&response)?; + + let metadata = serde_json::from_slice(&response.into_body())?; + + tracing::debug!(?metadata); + + Ok(metadata) +}