mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-11 01:13:14 -04:00
refactor(oidc): Import code to discover OIDC provider configuration
Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
committed by
Ivan Enderlin
parent
9adff21f78
commit
abc4fbc2f7
@@ -15,6 +15,7 @@
|
||||
//! Error types used in the [`Oidc`](super::Oidc) API.
|
||||
|
||||
pub use mas_oidc_client::error::*;
|
||||
use mas_oidc_client::types::oidc::ProviderMetadataVerificationError;
|
||||
use matrix_sdk_base::deserialized_responses::PrivOwnedStr;
|
||||
use oauth2::ErrorResponseType;
|
||||
pub use oauth2::{
|
||||
@@ -131,9 +132,18 @@ pub enum OauthDiscoveryError {
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// The server metadata is incomplete or insecure.
|
||||
#[error(transparent)]
|
||||
Validation(#[from] ProviderMetadataVerificationError),
|
||||
|
||||
/// An error occurred when building the OpenID Connect provider
|
||||
/// configuration URL.
|
||||
#[error(transparent)]
|
||||
Url(#[from] url::ParseError),
|
||||
|
||||
/// An error occurred when making a request to the OpenID Connect provider.
|
||||
#[error(transparent)]
|
||||
Oidc(#[from] DiscoveryError),
|
||||
Oidc(#[from] OauthRequestError<BasicErrorResponseType>),
|
||||
}
|
||||
|
||||
impl OauthDiscoveryError {
|
||||
|
||||
191
crates/matrix-sdk/src/authentication/oidc/http_client.rs
Normal file
191
crates/matrix-sdk/src/authentication/oidc/http_client.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
// Copyright 2025 Kévin Commaille
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! HTTP client and helpers for making OAuth 2.0 requests.
|
||||
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use oauth2::{
|
||||
AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError,
|
||||
};
|
||||
|
||||
/// An HTTP client for making OAuth 2.0 requests.
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct OauthHttpClient {
|
||||
pub(super) inner: reqwest::Client,
|
||||
/// Rewrite HTTPS requests to use HTTP instead.
|
||||
///
|
||||
/// This is a workaround to bypass some checks that require an HTTPS URL,
|
||||
/// but we can only mock HTTP URLs.
|
||||
#[cfg(test)]
|
||||
pub(super) insecure_rewrite_https_to_http: bool,
|
||||
}
|
||||
|
||||
impl<'c> AsyncHttpClient<'c> for OauthHttpClient {
|
||||
type Error = HttpClientError<reqwest::Error>;
|
||||
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
|
||||
|
||||
fn call(&'c self, request: HttpRequest) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
#[cfg(test)]
|
||||
let request = if self.insecure_rewrite_https_to_http
|
||||
&& request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
|
||||
{
|
||||
let mut request = request;
|
||||
|
||||
let mut uri_parts = request.uri().clone().into_parts();
|
||||
uri_parts.scheme = Some(http::uri::Scheme::HTTP);
|
||||
*request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
|
||||
.expect("reconstructing URI from parts should work");
|
||||
|
||||
request
|
||||
} else {
|
||||
request
|
||||
};
|
||||
|
||||
let response = self.inner.call(request).await?;
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Check the status code of the given HTTP response to identify errors.
|
||||
pub(super) fn check_http_response_status_code<T: ErrorResponse + 'static>(
|
||||
http_response: &HttpResponse,
|
||||
) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
|
||||
if http_response.status().as_u16() < 400 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let reason = http_response.body().as_slice();
|
||||
let error = if reason.is_empty() {
|
||||
RequestTokenError::Other("server returned an empty error response".to_owned())
|
||||
} else {
|
||||
match serde_json::from_slice(reason) {
|
||||
Ok(error) => RequestTokenError::ServerResponse(error),
|
||||
Err(error) => RequestTokenError::Other(error.to_string()),
|
||||
}
|
||||
};
|
||||
|
||||
Err(error)
|
||||
}
|
||||
|
||||
/// Check that the server returned a response with a JSON `Content-Type`.
|
||||
pub(super) fn check_http_response_json_content_type<T: ErrorResponse + 'static>(
|
||||
http_response: &HttpResponse,
|
||||
) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
|
||||
let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if content_type
|
||||
.to_str()
|
||||
// Check only the beginning of the content type, because there might be extra
|
||||
// parameters, like a charset.
|
||||
.is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str()))
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(RequestTokenError::Other(format!(
|
||||
"unexpected response Content-Type: {content_type:?}, should be `{}`",
|
||||
mime::APPLICATION_JSON
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use assert_matches2::assert_matches;
|
||||
use oauth2::{basic::BasicErrorResponse, RequestTokenError};
|
||||
|
||||
use super::{check_http_response_json_content_type, check_http_response_status_code};
|
||||
|
||||
#[test]
|
||||
fn test_check_http_response_status_code() {
|
||||
// OK
|
||||
let response = http::Response::builder().status(200).body(Vec::<u8>::new()).unwrap();
|
||||
assert_matches!(check_http_response_status_code::<BasicErrorResponse>(&response), Ok(()));
|
||||
|
||||
// Error without body.
|
||||
let response = http::Response::builder().status(404).body(Vec::<u8>::new()).unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_status_code::<BasicErrorResponse>(&response),
|
||||
Err(RequestTokenError::Other(_))
|
||||
);
|
||||
|
||||
// Error with invalid body.
|
||||
let response =
|
||||
http::Response::builder().status(404).body(b"invalid error format".to_vec()).unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_status_code::<BasicErrorResponse>(&response),
|
||||
Err(RequestTokenError::Other(_))
|
||||
);
|
||||
|
||||
// Error with valid body.
|
||||
let response = http::Response::builder()
|
||||
.status(404)
|
||||
.body(br#"{"error": "invalid_request"}"#.to_vec())
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_status_code::<BasicErrorResponse>(&response),
|
||||
Err(RequestTokenError::ServerResponse(_))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_http_response_json_content_type() {
|
||||
// Valid content type.
|
||||
let response = http::Response::builder()
|
||||
.status(200)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(b"{}".to_vec())
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_json_content_type::<BasicErrorResponse>(&response),
|
||||
Ok(())
|
||||
);
|
||||
|
||||
// Valid content type with charset.
|
||||
let response = http::Response::builder()
|
||||
.status(200)
|
||||
.header(http::header::CONTENT_TYPE, "application/json; charset=utf-8")
|
||||
.body(b"{}".to_vec())
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_json_content_type::<BasicErrorResponse>(&response),
|
||||
Ok(())
|
||||
);
|
||||
|
||||
// Without content type.
|
||||
let response = http::Response::builder().status(200).body(b"{}".to_vec()).unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_json_content_type::<BasicErrorResponse>(&response),
|
||||
Ok(())
|
||||
);
|
||||
|
||||
// Wrong content type.
|
||||
let response = http::Response::builder()
|
||||
.status(200)
|
||||
.header(http::header::CONTENT_TYPE, "text/html")
|
||||
.body(b"<html><body><h1>HTML!</h1></body></html>".to_vec())
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
check_http_response_json_content_type::<BasicErrorResponse>(&response),
|
||||
Err(RequestTokenError::Other(_))
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -146,7 +146,7 @@
|
||||
//! [`AuthenticateError::InsufficientScope`]: ruma::api::client::error::AuthenticateError
|
||||
//! [`examples/oidc_cli`]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/examples/oidc_cli
|
||||
|
||||
use std::{borrow::Cow, collections::HashMap, fmt, future::Future, pin::Pin, sync::Arc};
|
||||
use std::{borrow::Cow, collections::HashMap, fmt, sync::Arc};
|
||||
|
||||
use as_variant::as_variant;
|
||||
use error::{
|
||||
@@ -155,10 +155,7 @@ use error::{
|
||||
};
|
||||
use mas_oidc_client::{
|
||||
http_service::HttpService,
|
||||
requests::{
|
||||
discovery::{discover, insecure_discover},
|
||||
registration::register_client,
|
||||
},
|
||||
requests::registration::register_client,
|
||||
types::{
|
||||
oidc::{
|
||||
AccountManagementAction, ProviderMetadata, ProviderMetadataVerificationError,
|
||||
@@ -173,15 +170,18 @@ use matrix_sdk_base::crypto::types::qr_login::QrCodeData;
|
||||
use matrix_sdk_base::{once_cell::sync::OnceCell, SessionMeta};
|
||||
pub use oauth2::CsrfToken;
|
||||
use oauth2::{
|
||||
basic::BasicClient as OauthClient, AccessToken, AsyncHttpClient, HttpRequest, HttpResponse,
|
||||
PkceCodeVerifier, RedirectUrl, RefreshToken, RevocationUrl, Scope, StandardErrorResponse,
|
||||
StandardRevocableToken, TokenResponse, TokenUrl,
|
||||
basic::BasicClient as OauthClient, AccessToken, PkceCodeVerifier, RedirectUrl, RefreshToken,
|
||||
RevocationUrl, Scope, StandardErrorResponse, StandardRevocableToken, TokenResponse, TokenUrl,
|
||||
};
|
||||
use ruma::{
|
||||
api::client::discovery::{
|
||||
get_authentication_issuer,
|
||||
get_authorization_server_metadata::{self, msc2965::Prompt},
|
||||
get_authorization_server_metadata::{
|
||||
self,
|
||||
msc2965::{AuthorizationServerMetadata, Prompt},
|
||||
},
|
||||
},
|
||||
serde::Raw,
|
||||
DeviceId, OwnedDeviceId,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -194,6 +194,8 @@ mod account_management_url;
|
||||
mod auth_code_builder;
|
||||
mod cross_process;
|
||||
pub mod error;
|
||||
mod http_client;
|
||||
mod oidc_discovery;
|
||||
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
|
||||
pub mod qrcode;
|
||||
pub mod registrations;
|
||||
@@ -203,6 +205,8 @@ mod tests;
|
||||
use self::{
|
||||
account_management_url::build_account_management_url,
|
||||
cross_process::{CrossProcessRefreshLockGuard, CrossProcessRefreshManager},
|
||||
http_client::OauthHttpClient,
|
||||
oidc_discovery::discover,
|
||||
qrcode::LoginWithQrCode,
|
||||
registrations::{ClientId, OidcRegistrations},
|
||||
};
|
||||
@@ -677,8 +681,7 @@ impl Oidc {
|
||||
/// MSC2956: https://github.com/matrix-org/matrix-spec-proposals/pull/2965
|
||||
async fn fallback_discover(
|
||||
&self,
|
||||
insecure: bool,
|
||||
) -> Result<VerifiedProviderMetadata, OauthDiscoveryError> {
|
||||
) -> Result<Raw<AuthorizationServerMetadata>, OauthDiscoveryError> {
|
||||
#[allow(deprecated)]
|
||||
let issuer =
|
||||
match self.client.send(get_authentication_issuer::msc2965::Request::new()).await {
|
||||
@@ -693,11 +696,7 @@ impl Oidc {
|
||||
Err(error) => return Err(error.into()),
|
||||
};
|
||||
|
||||
if insecure {
|
||||
insecure_discover(&self.http_service(), &issuer).await.map_err(Into::into)
|
||||
} else {
|
||||
discover(&self.http_service(), &issuer).await.map_err(Into::into)
|
||||
}
|
||||
discover(self.http_client(), &issuer).await
|
||||
}
|
||||
|
||||
/// Fetch the OAuth 2.0 server metadata of the homeserver.
|
||||
@@ -711,33 +710,35 @@ impl Oidc {
|
||||
.is_some_and(|err| err.status_code == http::StatusCode::NOT_FOUND)
|
||||
};
|
||||
|
||||
match self.client.send(get_authorization_server_metadata::msc2965::Request::new()).await {
|
||||
Ok(response) => {
|
||||
let metadata = response.metadata.deserialize_as::<ProviderMetadata>()?;
|
||||
|
||||
let result = if self.ctx().insecure_discover {
|
||||
metadata.insecure_verify_metadata()
|
||||
} else {
|
||||
// The mas-oidc-client method needs to compare the issuer for validation. It's a
|
||||
// bit unnecessary because we take it from the metadata, oh well.
|
||||
let issuer =
|
||||
metadata.issuer.clone().ok_or(error::DiscoveryError::Validation(
|
||||
ProviderMetadataVerificationError::MissingIssuer,
|
||||
))?;
|
||||
metadata.validate(&issuer)
|
||||
};
|
||||
|
||||
Ok(result.map_err(error::DiscoveryError::Validation)?)
|
||||
}
|
||||
let raw_metadata = match self
|
||||
.client
|
||||
.send(get_authorization_server_metadata::msc2965::Request::new())
|
||||
.await
|
||||
{
|
||||
Ok(response) => response.metadata,
|
||||
// If the endpoint returns a 404, i.e. the server doesn't support the endpoint, attempt
|
||||
// to use the equivalent, but deprecated, endpoint.
|
||||
Err(error) if is_endpoint_unsupported(&error) => {
|
||||
// TODO: remove this fallback behavior when the metadata endpoint has wider
|
||||
// support.
|
||||
self.fallback_discover(self.ctx().insecure_discover).await
|
||||
self.fallback_discover().await?
|
||||
}
|
||||
Err(error) => Err(error.into()),
|
||||
}
|
||||
Err(error) => return Err(error.into()),
|
||||
};
|
||||
|
||||
let metadata = raw_metadata.deserialize_as::<ProviderMetadata>()?;
|
||||
|
||||
let metadata = if self.ctx().insecure_discover {
|
||||
metadata.insecure_verify_metadata()?
|
||||
} else {
|
||||
// The mas-oidc-client method needs to compare the issuer for validation. It's a
|
||||
// bit unnecessary because we take it from the metadata, oh well.
|
||||
let issuer =
|
||||
metadata.issuer.clone().ok_or(ProviderMetadataVerificationError::MissingIssuer)?;
|
||||
metadata.validate(&issuer)?
|
||||
};
|
||||
|
||||
Ok(metadata)
|
||||
}
|
||||
|
||||
/// The OpenID Connect unique identifier of this client obtained after
|
||||
@@ -1613,45 +1614,3 @@ pub struct AuthorizationError {
|
||||
fn hash_str(x: &str) -> impl fmt::LowerHex {
|
||||
sha2::Sha256::new().chain_update(x).finalize()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OauthHttpClient {
|
||||
inner: reqwest::Client,
|
||||
/// Rewrite HTTPS requests to use HTTP instead.
|
||||
///
|
||||
/// This is a workaround to bypass some checks that require an HTTPS URL,
|
||||
/// but we can only mock HTTP URLs.
|
||||
#[cfg(test)]
|
||||
insecure_rewrite_https_to_http: bool,
|
||||
}
|
||||
|
||||
impl<'c> AsyncHttpClient<'c> for OauthHttpClient {
|
||||
type Error = error::HttpClientError<reqwest::Error>;
|
||||
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
|
||||
|
||||
fn call(&'c self, request: HttpRequest) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
#[cfg(test)]
|
||||
let request = if self.insecure_rewrite_https_to_http
|
||||
&& request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
|
||||
{
|
||||
let mut request = request;
|
||||
|
||||
let mut uri_parts = request.uri().clone().into_parts();
|
||||
uri_parts.scheme = Some(http::uri::Scheme::HTTP);
|
||||
*request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
|
||||
.expect("reconstructing URI from parts should work");
|
||||
|
||||
request
|
||||
} else {
|
||||
request
|
||||
};
|
||||
|
||||
let response = self.inner.call(request).await?;
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
64
crates/matrix-sdk/src/authentication/oidc/oidc_discovery.rs
Normal file
64
crates/matrix-sdk/src/authentication/oidc/oidc_discovery.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2025 Kévin Commaille
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Requests for [OpenID Connect Provider Discovery].
|
||||
//!
|
||||
//! [OpenID Connect Provider Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
|
||||
use oauth2::{AsyncHttpClient, HttpClientError, RequestTokenError};
|
||||
use ruma::{
|
||||
api::client::discovery::get_authorization_server_metadata::msc2965::AuthorizationServerMetadata,
|
||||
serde::Raw,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::{
|
||||
error::OauthDiscoveryError,
|
||||
http_client::{check_http_response_json_content_type, check_http_response_status_code},
|
||||
OauthHttpClient,
|
||||
};
|
||||
|
||||
/// Fetch the OpenID Connect provider metadata.
|
||||
pub(super) async fn discover(
|
||||
http_client: &OauthHttpClient,
|
||||
issuer: &str,
|
||||
) -> Result<Raw<AuthorizationServerMetadata>, OauthDiscoveryError> {
|
||||
tracing::debug!("Fetching provider metadata...");
|
||||
|
||||
let mut url = Url::parse(issuer)?;
|
||||
|
||||
// If the path doesn't end with a slash, the last segment is removed when
|
||||
// using `join`.
|
||||
if !url.path().ends_with('/') {
|
||||
let mut path = url.path().to_owned();
|
||||
path.push('/');
|
||||
url.set_path(&path);
|
||||
}
|
||||
|
||||
let config_url = url.join(".well-known/openid-configuration")?;
|
||||
|
||||
let request = http::Request::get(config_url.as_str())
|
||||
.body(Vec::new())
|
||||
.map_err(|err| RequestTokenError::Request(HttpClientError::Http(err)))?;
|
||||
let response = http_client.call(request).await.map_err(RequestTokenError::Request)?;
|
||||
|
||||
check_http_response_status_code(&response)?;
|
||||
check_http_response_json_content_type(&response)?;
|
||||
|
||||
let metadata = serde_json::from_slice(&response.into_body())?;
|
||||
|
||||
tracing::debug!(?metadata);
|
||||
|
||||
Ok(metadata)
|
||||
}
|
||||
Reference in New Issue
Block a user