From 26cb805e0f0628fd53b92daeae135d80932dd07e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Wed, 26 Feb 2025 14:42:47 +0100 Subject: [PATCH] test(oidc): Use MatrixMockServer in the remaining tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Gets rid of the MockImpl for OidcBackend. Signed-off-by: Kévin Commaille --- .../src/authentication/oidc/backend/mock.rs | 209 ------------------ .../src/authentication/oidc/backend/mod.rs | 3 - .../src/authentication/oidc/cross_process.rs | 90 +++----- .../matrix-sdk/src/authentication/oidc/mod.rs | 2 +- .../src/authentication/oidc/tests.rs | 76 +++---- 5 files changed, 64 insertions(+), 316 deletions(-) delete mode 100644 crates/matrix-sdk/src/authentication/oidc/backend/mock.rs diff --git a/crates/matrix-sdk/src/authentication/oidc/backend/mock.rs b/crates/matrix-sdk/src/authentication/oidc/backend/mock.rs deleted file mode 100644 index b2316b4b1..000000000 --- a/crates/matrix-sdk/src/authentication/oidc/backend/mock.rs +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2023 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for that specific language governing permissions and -// limitations under the License. - -//! Test implementation of the OIDC backend. - -use std::sync::{Arc, Mutex}; - -use http::StatusCode; -use mas_oidc_client::{ - error::{ - DiscoveryError as OidcDiscoveryError, Error as OidcClientError, ErrorBody as OidcErrorBody, - HttpError as OidcHttpError, TokenRefreshError, TokenRequestError, - }, - requests::authorization_code::AuthorizationValidationData, - types::{ - client_credentials::ClientCredentials, - errors::ClientErrorCode, - iana::oauth::OAuthTokenTypeHint, - oidc::{ProviderMetadataVerificationError, VerifiedProviderMetadata}, - registration::{ClientRegistrationResponse, VerifiedClientMetadata}, - }, -}; -use url::Url; - -use super::{OidcBackend, OidcError, RefreshedSessionTokens}; -use crate::{ - authentication::oidc::{AuthorizationCode, OauthDiscoveryError, OidcSessionTokens}, - test_utils::mocks::oauth::MockServerMetadataBuilder, -}; - -pub(crate) const ISSUER_URL: &str = "https://oidc.example.com/issuer"; -pub(crate) const CLIENT_ID: &str = "test_client_id"; - -#[derive(Debug)] -pub(crate) struct MockImpl { - /// Must be an HTTPS URL. - issuer: String, - - /// The next session tokens that will be returned by a login or refresh. - next_session_tokens: Option, - - /// The next refresh token that's expected for a refresh. - expected_refresh_token: Option, - - /// Number of refreshes that effectively happened. - pub num_refreshes: Arc>, - - /// Tokens that have been revoked with `revoke_token`. - pub revoked_tokens: Arc>>, - - /// Should we only accept insecure flags during discovery? - is_insecure: bool, -} - -impl MockImpl { - pub fn new() -> Self { - Self { - issuer: ISSUER_URL.to_owned(), - next_session_tokens: None, - expected_refresh_token: None, - num_refreshes: Default::default(), - revoked_tokens: Default::default(), - is_insecure: false, - } - } - - pub fn next_session_tokens(mut self, next_session_tokens: OidcSessionTokens) -> Self { - self.next_session_tokens = Some(next_session_tokens); - self - } - - pub fn expected_refresh_token(mut self, refresh_token: String) -> Self { - self.expected_refresh_token = Some(refresh_token); - self - } - - pub fn mark_insecure(mut self) -> Self { - self.is_insecure = true; - self - } -} - -#[async_trait::async_trait] -impl OidcBackend for MockImpl { - async fn discover( - &self, - insecure: bool, - ) -> Result { - if insecure != self.is_insecure { - return Err(OidcDiscoveryError::Validation( - ProviderMetadataVerificationError::UrlNonHttpsScheme( - "mocking backend", - Url::parse(&self.issuer).unwrap(), - ), - ) - .into()); - } - - Ok(MockServerMetadataBuilder::new(&self.issuer) - .build() - .validate(&self.issuer) - .expect("server metadata should pass validation")) - } - - async fn trade_authorization_code_for_tokens( - &self, - _provider_metadata: VerifiedProviderMetadata, - _credentials: ClientCredentials, - _auth_code: AuthorizationCode, - _validation_data: AuthorizationValidationData, - ) -> Result { - Ok(self - .next_session_tokens - .as_ref() - .expect("missing next session tokens in testing") - .clone()) - } - - async fn register_client( - &self, - _registration_endpoint: &Url, - _client_metadata: VerifiedClientMetadata, - _software_statement: Option, - ) -> Result { - Ok(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_id_issued_at: None, - client_secret_expires_at: None, - }) - } - - async fn revoke_token( - &self, - _client_credentials: ClientCredentials, - _revocation_endpoint: &Url, - token: String, - _token_type_hint: Option, - ) -> Result<(), OidcError> { - self.revoked_tokens.lock().unwrap().push(token); - Ok(()) - } - - async fn refresh_access_token( - &self, - _provider_metadata: VerifiedProviderMetadata, - _credentials: ClientCredentials, - refresh_token: String, - ) -> Result { - if Some(refresh_token) != self.expected_refresh_token { - Err(OidcError::Oidc(OidcClientError::TokenRefresh(TokenRefreshError::Token( - TokenRequestError::Http(OidcHttpError { - body: Some(OidcErrorBody { - error: ClientErrorCode::InvalidGrant, - error_description: None, - }), - status: StatusCode::from_u16(400).unwrap(), - }), - )))) - } else { - *self.num_refreshes.lock().unwrap() += 1; - let next_tokens = self.next_session_tokens.clone().expect("missing next tokens"); - Ok(RefreshedSessionTokens { - access_token: next_tokens.access_token, - refresh_token: next_tokens.refresh_token, - }) - } - } - - #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] - async fn request_device_authorization( - &self, - _device_authorization_endpoint: Url, - _client_id: oauth2::ClientId, - _scopes: Vec, - ) -> Result< - oauth2::StandardDeviceAuthorizationResponse, - oauth2::basic::BasicRequestTokenError>, - > { - unimplemented!() - } - - #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] - async fn exchange_device_code( - &self, - _token_endpoint: Url, - _client_id: oauth2::ClientId, - _device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse, - ) -> Result< - OidcSessionTokens, - oauth2::RequestTokenError< - oauth2::HttpClientError, - oauth2::DeviceCodeErrorResponse, - >, - > { - unimplemented!() - } -} diff --git a/crates/matrix-sdk/src/authentication/oidc/backend/mod.rs b/crates/matrix-sdk/src/authentication/oidc/backend/mod.rs index fe659276b..42f6727c2 100644 --- a/crates/matrix-sdk/src/authentication/oidc/backend/mod.rs +++ b/crates/matrix-sdk/src/authentication/oidc/backend/mod.rs @@ -31,9 +31,6 @@ use super::{AuthorizationCode, OauthDiscoveryError, OidcError, OidcSessionTokens pub(crate) mod server; -#[cfg(test)] -pub(crate) mod mock; - pub(super) struct RefreshedSessionTokens { pub access_token: String, pub refresh_token: Option, diff --git a/crates/matrix-sdk/src/authentication/oidc/cross_process.rs b/crates/matrix-sdk/src/authentication/oidc/cross_process.rs index 58a9c1518..baa505f87 100644 --- a/crates/matrix-sdk/src/authentication/oidc/cross_process.rs +++ b/crates/matrix-sdk/src/authentication/oidc/cross_process.rs @@ -251,7 +251,6 @@ pub enum CrossProcessRefreshLockError { #[cfg(all(test, feature = "e2e-encryption"))] mod tests { - use std::sync::Arc; use anyhow::Context as _; use futures_util::future::join_all; @@ -261,12 +260,7 @@ mod tests { use super::compute_session_hash; use crate::{ - authentication::oidc::{ - backend::mock::{MockImpl, ISSUER_URL}, - cross_process::SessionHash, - tests::prev_session_tokens, - Oidc, - }, + authentication::oidc::{cross_process::SessionHash, tests::prev_session_tokens}, test_utils::{ client::{ oauth::{mock_session, mock_session_tokens}, @@ -302,7 +296,13 @@ mod tests { )?; let session_hash = compute_session_hash(&tokens); - client.oidc().restore_session(mock_session(tokens.clone(), ISSUER_URL.to_owned())).await?; + client + .oidc() + .restore_session(mock_session( + tokens.clone(), + "https://oidc.example.com/issuer".to_owned(), + )) + .await?; assert_eq!(client.oidc().session_tokens().unwrap(), tokens); @@ -376,37 +376,29 @@ mod tests { // This tests that refresh token works, and that it doesn't cause multiple token // refreshes whenever one spawns two refreshes around the same time. + let server = MatrixMockServer::new().await; + + let oauth_server = server.oauth(); + oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_token().ok().expect(1).named("token").mount().await; + let tmp_dir = tempfile::tempdir()?; - let client = MockClientBuilder::new("https://example.org".to_owned()) - .sqlite_store(&tmp_dir) - .unlogged() - .build() - .await; + let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await; + let oidc = client.oidc(); - let prev_tokens = prev_session_tokens(); let next_tokens = mock_session_tokens(); - let backend = Arc::new( - MockImpl::new() - .next_session_tokens(next_tokens.clone()) - .expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()), - ); - let oidc = Oidc { client: client.clone(), backend: backend.clone() }; - // Enable cross-process lock. oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?; // Restore the session. - oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?; + oidc.restore_session(mock_session(prev_session_tokens(), server.server().uri())).await?; // Immediately try to refresh the access token twice in parallel. for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await { result?; } - // There should have been at most one refresh. - assert_eq!(*backend.num_refreshes.lock().unwrap(), 1); - { // The cross process lock has been correctly updated, and the next attempt to // take it won't result in a mismatch. @@ -424,51 +416,40 @@ mod tests { #[async_test] async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> { - // Create the backend. + let server = MatrixMockServer::new().await; + let issuer = server.server().uri(); + + let oauth_server = server.oauth(); + oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_token().ok().expect(1).named("token").mount().await; + let prev_tokens = prev_session_tokens(); let next_tokens = mock_session_tokens(); - let backend = Arc::new( - MockImpl::new() - .next_session_tokens(next_tokens.clone()) - .expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()), - ); - // Create the first client. let tmp_dir = tempfile::tempdir()?; - let client = MockClientBuilder::new("https://example.org".to_owned()) - .sqlite_store(&tmp_dir) - .unlogged() - .build() - .await; + let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await; - let oidc = Oidc { client: client.clone(), backend: backend.clone() }; + let oidc = client.oidc(); oidc.enable_cross_process_refresh_lock("client1".to_owned()).await?; - oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?; + oidc.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?; // Create a second client, without restoring it, to test that a token update // before restoration doesn't cause new issues. - let unrestored_client = MockClientBuilder::new("https://example.org".to_owned()) - .sqlite_store(&tmp_dir) - .unlogged() - .build() - .await; - let unrestored_oidc = Oidc { client: unrestored_client.clone(), backend: backend.clone() }; + let unrestored_client = + server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await; + let unrestored_oidc = unrestored_client.oidc(); unrestored_oidc.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?; { // Create a third client that will run a refresh while the others two are doing // nothing. - let client3 = MockClientBuilder::new("https://example.org".to_owned()) - .sqlite_store(&tmp_dir) - .unlogged() - .build() - .await; + let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await; - let oidc3 = Oidc { client: client3.clone(), backend: backend.clone() }; + let oidc3 = client3.oidc(); oidc3.enable_cross_process_refresh_lock("client3".to_owned()).await?; - oidc3.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?; + oidc3.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?; // Run a refresh in the second client; this will invalidate the tokens from the // first token. @@ -500,7 +481,7 @@ mod tests { Box::new(|_| panic!("save_session_callback shouldn't be called here")), )?; - oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?; + oidc.restore_session(mock_session(prev_tokens.clone(), issuer)).await?; // And this client is now aware of the latest tokens. let xp_manager = @@ -550,9 +531,6 @@ mod tests { assert!(!guard.hash_mismatch); } - // There should have been at most one refresh. - assert_eq!(*backend.num_refreshes.lock().unwrap(), 1); - Ok(()) } diff --git a/crates/matrix-sdk/src/authentication/oidc/mod.rs b/crates/matrix-sdk/src/authentication/oidc/mod.rs index d1436f1bb..65cfc8af3 100644 --- a/crates/matrix-sdk/src/authentication/oidc/mod.rs +++ b/crates/matrix-sdk/src/authentication/oidc/mod.rs @@ -1358,7 +1358,7 @@ impl Oidc { let provider_metadata = match self.provider_metadata().await { Ok(metadata) => metadata, Err(err) => { - warn!("couldn't get authorization server metadata: {err}"); + warn!("couldn't get authorization server metadata: {err:?}"); fail!(refresh_status_guard, RefreshTokenError::Oidc(Arc::new(err.into()))); } }; diff --git a/crates/matrix-sdk/src/authentication/oidc/tests.rs b/crates/matrix-sdk/src/authentication/oidc/tests.rs index 5bb6dded3..85fd864bf 100644 --- a/crates/matrix-sdk/src/authentication/oidc/tests.rs +++ b/crates/matrix-sdk/src/authentication/oidc/tests.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use anyhow::Context as _; use assert_matches::assert_matches; @@ -21,15 +21,13 @@ use wiremock::{ }; use super::{ - backend::mock::{MockImpl, ISSUER_URL}, - registrations::OidcRegistrations, - AuthorizationCode, AuthorizationError, AuthorizationResponse, Oidc, OidcError, - OidcSessionTokens, RedirectUriQueryParseError, + registrations::OidcRegistrations, AuthorizationCode, AuthorizationError, AuthorizationResponse, + Oidc, OidcError, OidcSessionTokens, RedirectUriQueryParseError, }; use crate::{ test_utils::{ client::{ - oauth::{mock_client_id, mock_client_metadata, mock_session, mock_session_tokens}, + oauth::{mock_client_metadata, mock_session, mock_session_tokens}, MockClientBuilder, }, mocks::{oauth::MockServerMetadataBuilder, MatrixMockServer}, @@ -51,32 +49,22 @@ pub(crate) fn prev_session_tokens() -> OidcSessionTokens { async fn mock_environment( ) -> anyhow::Result<(Oidc, MatrixMockServer, VerifiedClientMetadata, OidcRegistrations)> { let server = MatrixMockServer::new().await; - let issuer_url = Url::parse(ISSUER_URL).unwrap(); - server.mock_who_am_i().ok().named("whoami").mount().await; + let oauth_server = server.oauth(); + oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_registration().ok().expect(1).named("registration").mount().await; + oauth_server.mock_token().ok().mount().await; + let client = server.client_builder().unlogged().build().await; - - let session_tokens = mock_session_tokens(); - - let oidc = Oidc { - client, - backend: Arc::new(MockImpl::new().mark_insecure().next_session_tokens(session_tokens)), - }; - - let client_id = mock_client_id(); let client_metadata = mock_client_metadata(); - // The mock backend doesn't support registration so set a static registration. - let mut static_registrations = HashMap::new(); - static_registrations.insert(issuer_url, client_id); - let registrations_path = tempdir().unwrap().path().join("oidc").join("registrations.json"); let registrations = - OidcRegistrations::new(®istrations_path, client_metadata.clone(), static_registrations) + OidcRegistrations::new(®istrations_path, client_metadata.clone(), HashMap::new()) .unwrap(); - Ok((oidc, server, client_metadata, registrations)) + Ok((client.oidc(), server, client_metadata, registrations)) } #[async_test] @@ -248,16 +236,14 @@ fn test_authorization_response() -> anyhow::Result<()> { #[async_test] async fn test_finish_authorization() -> anyhow::Result<()> { - let client = MockClientBuilder::new("https://example.org".to_owned()) - .registered_with_oauth(ISSUER_URL) - .build() - .await; + let server = MatrixMockServer::new().await; + let oauth_server = server.oauth(); - let session_tokens = mock_session_tokens(); - let oidc = Oidc { - client: client.clone(), - backend: Arc::new(MockImpl::new().next_session_tokens(session_tokens.clone())), - }; + oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await; + oauth_server.mock_token().ok().expect(1).named("token").mount().await; + + let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await; + let oidc = client.oidc(); // If the state is missing, then any attempt to finish authorizing will fail. let res = oidc @@ -313,7 +299,7 @@ async fn test_oidc_session() -> anyhow::Result<()> { let oidc = client.oidc(); let tokens = mock_session_tokens(); - let issuer = ISSUER_URL; + let issuer = "https://oidc.example.com/issuer"; let session = mock_session(tokens.clone(), issuer.to_owned()); oidc.restore_session(session.clone()).await?; @@ -324,14 +310,14 @@ async fn test_oidc_session() -> anyhow::Result<()> { let user_session = oidc.user_session().unwrap(); assert_eq!(user_session.meta, session.user.meta); assert_eq!(user_session.tokens, tokens); - assert_eq!(user_session.issuer, ISSUER_URL); + assert_eq!(user_session.issuer, issuer); let full_session = oidc.full_session().unwrap(); assert_eq!(full_session.client_id.0, "test_client_id"); assert_eq!(full_session.user.meta, session.user.meta); assert_eq!(full_session.user.tokens, tokens); - assert_eq!(full_session.user.issuer, ISSUER_URL); + assert_eq!(full_session.user.issuer, issuer); Ok(()) } @@ -342,13 +328,18 @@ async fn test_insecure_clients() -> anyhow::Result<()> { let server_url = server.server().uri(); server.mock_well_known().ok().expect(1).named("well_known").mount().await; + server.mock_versions().ok().expect(1..).named("versions").mount().await; + + let oauth_server = server.oauth(); + oauth_server.mock_server_metadata().ok().expect(2..).named("server_metadata").mount().await; + oauth_server.mock_token().ok().expect(2).named("token").mount().await; let prev_tokens = prev_session_tokens(); let next_tokens = mock_session_tokens(); for client in [ // Create an insecure client with the homeserver_url method. - Client::builder().homeserver_url("http://example.org").build().await?, + Client::builder().homeserver_url(&server_url).build().await?, // Create an insecure client with the insecure_server_name_no_tls method. Client::builder() .insecure_server_name_no_tls(&ServerName::parse( @@ -357,16 +348,10 @@ async fn test_insecure_clients() -> anyhow::Result<()> { .build() .await?, ] { - let backend = Arc::new( - MockImpl::new() - .mark_insecure() - .next_session_tokens(next_tokens.clone()) - .expected_refresh_token(prev_tokens.refresh_token.as_ref().unwrap().clone()), - ); - let oidc = Oidc { client: client.clone(), backend: backend.clone() }; + let oidc = client.oidc(); // Restore the previous session so we have an existing set of refresh tokens. - oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?; + oidc.restore_session(mock_session(prev_tokens.clone(), server_url.clone())).await?; let mut session_token_stream = oidc.session_tokens_stream().expect("stream available"); @@ -380,9 +365,6 @@ async fn test_insecure_clients() -> anyhow::Result<()> { }); assert_pending!(session_token_stream); - - // There should have been exactly one refresh. - assert_eq!(*backend.num_refreshes.lock().unwrap(), 1); } Ok(())