test(oidc): Use MatrixMockServer in the remaining tests

Gets rid of the MockImpl for OidcBackend.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-02-26 14:42:47 +01:00
committed by Damir Jelić
parent 81dbe2060c
commit 26cb805e0f
5 changed files with 64 additions and 316 deletions

View File

@@ -1,209 +0,0 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for that specific language governing permissions and
// limitations under the License.
//! Test implementation of the OIDC backend.
use std::sync::{Arc, Mutex};
use http::StatusCode;
use mas_oidc_client::{
error::{
DiscoveryError as OidcDiscoveryError, Error as OidcClientError, ErrorBody as OidcErrorBody,
HttpError as OidcHttpError, TokenRefreshError, TokenRequestError,
},
requests::authorization_code::AuthorizationValidationData,
types::{
client_credentials::ClientCredentials,
errors::ClientErrorCode,
iana::oauth::OAuthTokenTypeHint,
oidc::{ProviderMetadataVerificationError, VerifiedProviderMetadata},
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
},
};
use url::Url;
use super::{OidcBackend, OidcError, RefreshedSessionTokens};
use crate::{
authentication::oidc::{AuthorizationCode, OauthDiscoveryError, OidcSessionTokens},
test_utils::mocks::oauth::MockServerMetadataBuilder,
};
pub(crate) const ISSUER_URL: &str = "https://oidc.example.com/issuer";
pub(crate) const CLIENT_ID: &str = "test_client_id";
#[derive(Debug)]
pub(crate) struct MockImpl {
/// Must be an HTTPS URL.
issuer: String,
/// The next session tokens that will be returned by a login or refresh.
next_session_tokens: Option<OidcSessionTokens>,
/// The next refresh token that's expected for a refresh.
expected_refresh_token: Option<String>,
/// Number of refreshes that effectively happened.
pub num_refreshes: Arc<Mutex<u32>>,
/// Tokens that have been revoked with `revoke_token`.
pub revoked_tokens: Arc<Mutex<Vec<String>>>,
/// Should we only accept insecure flags during discovery?
is_insecure: bool,
}
impl MockImpl {
pub fn new() -> Self {
Self {
issuer: ISSUER_URL.to_owned(),
next_session_tokens: None,
expected_refresh_token: None,
num_refreshes: Default::default(),
revoked_tokens: Default::default(),
is_insecure: false,
}
}
pub fn next_session_tokens(mut self, next_session_tokens: OidcSessionTokens) -> Self {
self.next_session_tokens = Some(next_session_tokens);
self
}
pub fn expected_refresh_token(mut self, refresh_token: String) -> Self {
self.expected_refresh_token = Some(refresh_token);
self
}
pub fn mark_insecure(mut self) -> Self {
self.is_insecure = true;
self
}
}
#[async_trait::async_trait]
impl OidcBackend for MockImpl {
async fn discover(
&self,
insecure: bool,
) -> Result<VerifiedProviderMetadata, OauthDiscoveryError> {
if insecure != self.is_insecure {
return Err(OidcDiscoveryError::Validation(
ProviderMetadataVerificationError::UrlNonHttpsScheme(
"mocking backend",
Url::parse(&self.issuer).unwrap(),
),
)
.into());
}
Ok(MockServerMetadataBuilder::new(&self.issuer)
.build()
.validate(&self.issuer)
.expect("server metadata should pass validation"))
}
async fn trade_authorization_code_for_tokens(
&self,
_provider_metadata: VerifiedProviderMetadata,
_credentials: ClientCredentials,
_auth_code: AuthorizationCode,
_validation_data: AuthorizationValidationData,
) -> Result<OidcSessionTokens, OidcError> {
Ok(self
.next_session_tokens
.as_ref()
.expect("missing next session tokens in testing")
.clone())
}
async fn register_client(
&self,
_registration_endpoint: &Url,
_client_metadata: VerifiedClientMetadata,
_software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError> {
Ok(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
})
}
async fn revoke_token(
&self,
_client_credentials: ClientCredentials,
_revocation_endpoint: &Url,
token: String,
_token_type_hint: Option<OAuthTokenTypeHint>,
) -> Result<(), OidcError> {
self.revoked_tokens.lock().unwrap().push(token);
Ok(())
}
async fn refresh_access_token(
&self,
_provider_metadata: VerifiedProviderMetadata,
_credentials: ClientCredentials,
refresh_token: String,
) -> Result<RefreshedSessionTokens, OidcError> {
if Some(refresh_token) != self.expected_refresh_token {
Err(OidcError::Oidc(OidcClientError::TokenRefresh(TokenRefreshError::Token(
TokenRequestError::Http(OidcHttpError {
body: Some(OidcErrorBody {
error: ClientErrorCode::InvalidGrant,
error_description: None,
}),
status: StatusCode::from_u16(400).unwrap(),
}),
))))
} else {
*self.num_refreshes.lock().unwrap() += 1;
let next_tokens = self.next_session_tokens.clone().expect("missing next tokens");
Ok(RefreshedSessionTokens {
access_token: next_tokens.access_token,
refresh_token: next_tokens.refresh_token,
})
}
}
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn request_device_authorization(
&self,
_device_authorization_endpoint: Url,
_client_id: oauth2::ClientId,
_scopes: Vec<oauth2::Scope>,
) -> Result<
oauth2::StandardDeviceAuthorizationResponse,
oauth2::basic::BasicRequestTokenError<oauth2::HttpClientError<reqwest::Error>>,
> {
unimplemented!()
}
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn exchange_device_code(
&self,
_token_endpoint: Url,
_client_id: oauth2::ClientId,
_device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse,
) -> Result<
OidcSessionTokens,
oauth2::RequestTokenError<
oauth2::HttpClientError<reqwest::Error>,
oauth2::DeviceCodeErrorResponse,
>,
> {
unimplemented!()
}
}

View File

@@ -31,9 +31,6 @@ use super::{AuthorizationCode, OauthDiscoveryError, OidcError, OidcSessionTokens
pub(crate) mod server;
#[cfg(test)]
pub(crate) mod mock;
pub(super) struct RefreshedSessionTokens {
pub access_token: String,
pub refresh_token: Option<String>,

View File

@@ -251,7 +251,6 @@ pub enum CrossProcessRefreshLockError {
#[cfg(all(test, feature = "e2e-encryption"))]
mod tests {
use std::sync::Arc;
use anyhow::Context as _;
use futures_util::future::join_all;
@@ -261,12 +260,7 @@ mod tests {
use super::compute_session_hash;
use crate::{
authentication::oidc::{
backend::mock::{MockImpl, ISSUER_URL},
cross_process::SessionHash,
tests::prev_session_tokens,
Oidc,
},
authentication::oidc::{cross_process::SessionHash, tests::prev_session_tokens},
test_utils::{
client::{
oauth::{mock_session, mock_session_tokens},
@@ -302,7 +296,13 @@ mod tests {
)?;
let session_hash = compute_session_hash(&tokens);
client.oidc().restore_session(mock_session(tokens.clone(), ISSUER_URL.to_owned())).await?;
client
.oidc()
.restore_session(mock_session(
tokens.clone(),
"https://oidc.example.com/issuer".to_owned(),
))
.await?;
assert_eq!(client.oidc().session_tokens().unwrap(), tokens);
@@ -376,37 +376,29 @@ mod tests {
// This tests that refresh token works, and that it doesn't cause multiple token
// refreshes whenever one spawns two refreshes around the same time.
let server = MatrixMockServer::new().await;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_token().ok().expect(1).named("token").mount().await;
let tmp_dir = tempfile::tempdir()?;
let client = MockClientBuilder::new("https://example.org".to_owned())
.sqlite_store(&tmp_dir)
.unlogged()
.build()
.await;
let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
let oidc = client.oidc();
let prev_tokens = prev_session_tokens();
let next_tokens = mock_session_tokens();
let backend = Arc::new(
MockImpl::new()
.next_session_tokens(next_tokens.clone())
.expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()),
);
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
// Enable cross-process lock.
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
// Restore the session.
oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?;
oidc.restore_session(mock_session(prev_session_tokens(), server.server().uri())).await?;
// Immediately try to refresh the access token twice in parallel.
for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await {
result?;
}
// There should have been at most one refresh.
assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);
{
// The cross process lock has been correctly updated, and the next attempt to
// take it won't result in a mismatch.
@@ -424,51 +416,40 @@ mod tests {
#[async_test]
async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
// Create the backend.
let server = MatrixMockServer::new().await;
let issuer = server.server().uri();
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_token().ok().expect(1).named("token").mount().await;
let prev_tokens = prev_session_tokens();
let next_tokens = mock_session_tokens();
let backend = Arc::new(
MockImpl::new()
.next_session_tokens(next_tokens.clone())
.expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()),
);
// Create the first client.
let tmp_dir = tempfile::tempdir()?;
let client = MockClientBuilder::new("https://example.org".to_owned())
.sqlite_store(&tmp_dir)
.unlogged()
.build()
.await;
let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
let oidc = client.oidc();
oidc.enable_cross_process_refresh_lock("client1".to_owned()).await?;
oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?;
oidc.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?;
// Create a second client, without restoring it, to test that a token update
// before restoration doesn't cause new issues.
let unrestored_client = MockClientBuilder::new("https://example.org".to_owned())
.sqlite_store(&tmp_dir)
.unlogged()
.build()
.await;
let unrestored_oidc = Oidc { client: unrestored_client.clone(), backend: backend.clone() };
let unrestored_client =
server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
let unrestored_oidc = unrestored_client.oidc();
unrestored_oidc.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;
{
// Create a third client that will run a refresh while the others two are doing
// nothing.
let client3 = MockClientBuilder::new("https://example.org".to_owned())
.sqlite_store(&tmp_dir)
.unlogged()
.build()
.await;
let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
let oidc3 = Oidc { client: client3.clone(), backend: backend.clone() };
let oidc3 = client3.oidc();
oidc3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
oidc3.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?;
oidc3.restore_session(mock_session(prev_tokens.clone(), issuer.clone())).await?;
// Run a refresh in the second client; this will invalidate the tokens from the
// first token.
@@ -500,7 +481,7 @@ mod tests {
Box::new(|_| panic!("save_session_callback shouldn't be called here")),
)?;
oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?;
oidc.restore_session(mock_session(prev_tokens.clone(), issuer)).await?;
// And this client is now aware of the latest tokens.
let xp_manager =
@@ -550,9 +531,6 @@ mod tests {
assert!(!guard.hash_mismatch);
}
// There should have been at most one refresh.
assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);
Ok(())
}

View File

@@ -1358,7 +1358,7 @@ impl Oidc {
let provider_metadata = match self.provider_metadata().await {
Ok(metadata) => metadata,
Err(err) => {
warn!("couldn't get authorization server metadata: {err}");
warn!("couldn't get authorization server metadata: {err:?}");
fail!(refresh_status_guard, RefreshTokenError::Oidc(Arc::new(err.into())));
}
};

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use anyhow::Context as _;
use assert_matches::assert_matches;
@@ -21,15 +21,13 @@ use wiremock::{
};
use super::{
backend::mock::{MockImpl, ISSUER_URL},
registrations::OidcRegistrations,
AuthorizationCode, AuthorizationError, AuthorizationResponse, Oidc, OidcError,
OidcSessionTokens, RedirectUriQueryParseError,
registrations::OidcRegistrations, AuthorizationCode, AuthorizationError, AuthorizationResponse,
Oidc, OidcError, OidcSessionTokens, RedirectUriQueryParseError,
};
use crate::{
test_utils::{
client::{
oauth::{mock_client_id, mock_client_metadata, mock_session, mock_session_tokens},
oauth::{mock_client_metadata, mock_session, mock_session_tokens},
MockClientBuilder,
},
mocks::{oauth::MockServerMetadataBuilder, MatrixMockServer},
@@ -51,32 +49,22 @@ pub(crate) fn prev_session_tokens() -> OidcSessionTokens {
async fn mock_environment(
) -> anyhow::Result<(Oidc, MatrixMockServer, VerifiedClientMetadata, OidcRegistrations)> {
let server = MatrixMockServer::new().await;
let issuer_url = Url::parse(ISSUER_URL).unwrap();
server.mock_who_am_i().ok().named("whoami").mount().await;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
oauth_server.mock_token().ok().mount().await;
let client = server.client_builder().unlogged().build().await;
let session_tokens = mock_session_tokens();
let oidc = Oidc {
client,
backend: Arc::new(MockImpl::new().mark_insecure().next_session_tokens(session_tokens)),
};
let client_id = mock_client_id();
let client_metadata = mock_client_metadata();
// The mock backend doesn't support registration so set a static registration.
let mut static_registrations = HashMap::new();
static_registrations.insert(issuer_url, client_id);
let registrations_path = tempdir().unwrap().path().join("oidc").join("registrations.json");
let registrations =
OidcRegistrations::new(&registrations_path, client_metadata.clone(), static_registrations)
OidcRegistrations::new(&registrations_path, client_metadata.clone(), HashMap::new())
.unwrap();
Ok((oidc, server, client_metadata, registrations))
Ok((client.oidc(), server, client_metadata, registrations))
}
#[async_test]
@@ -248,16 +236,14 @@ fn test_authorization_response() -> anyhow::Result<()> {
#[async_test]
async fn test_finish_authorization() -> anyhow::Result<()> {
let client = MockClientBuilder::new("https://example.org".to_owned())
.registered_with_oauth(ISSUER_URL)
.build()
.await;
let server = MatrixMockServer::new().await;
let oauth_server = server.oauth();
let session_tokens = mock_session_tokens();
let oidc = Oidc {
client: client.clone(),
backend: Arc::new(MockImpl::new().next_session_tokens(session_tokens.clone())),
};
oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
oauth_server.mock_token().ok().expect(1).named("token").mount().await;
let client = server.client_builder().registered_with_oauth(server.server().uri()).build().await;
let oidc = client.oidc();
// If the state is missing, then any attempt to finish authorizing will fail.
let res = oidc
@@ -313,7 +299,7 @@ async fn test_oidc_session() -> anyhow::Result<()> {
let oidc = client.oidc();
let tokens = mock_session_tokens();
let issuer = ISSUER_URL;
let issuer = "https://oidc.example.com/issuer";
let session = mock_session(tokens.clone(), issuer.to_owned());
oidc.restore_session(session.clone()).await?;
@@ -324,14 +310,14 @@ async fn test_oidc_session() -> anyhow::Result<()> {
let user_session = oidc.user_session().unwrap();
assert_eq!(user_session.meta, session.user.meta);
assert_eq!(user_session.tokens, tokens);
assert_eq!(user_session.issuer, ISSUER_URL);
assert_eq!(user_session.issuer, issuer);
let full_session = oidc.full_session().unwrap();
assert_eq!(full_session.client_id.0, "test_client_id");
assert_eq!(full_session.user.meta, session.user.meta);
assert_eq!(full_session.user.tokens, tokens);
assert_eq!(full_session.user.issuer, ISSUER_URL);
assert_eq!(full_session.user.issuer, issuer);
Ok(())
}
@@ -342,13 +328,18 @@ async fn test_insecure_clients() -> anyhow::Result<()> {
let server_url = server.server().uri();
server.mock_well_known().ok().expect(1).named("well_known").mount().await;
server.mock_versions().ok().expect(1..).named("versions").mount().await;
let oauth_server = server.oauth();
oauth_server.mock_server_metadata().ok().expect(2..).named("server_metadata").mount().await;
oauth_server.mock_token().ok().expect(2).named("token").mount().await;
let prev_tokens = prev_session_tokens();
let next_tokens = mock_session_tokens();
for client in [
// Create an insecure client with the homeserver_url method.
Client::builder().homeserver_url("http://example.org").build().await?,
Client::builder().homeserver_url(&server_url).build().await?,
// Create an insecure client with the insecure_server_name_no_tls method.
Client::builder()
.insecure_server_name_no_tls(&ServerName::parse(
@@ -357,16 +348,10 @@ async fn test_insecure_clients() -> anyhow::Result<()> {
.build()
.await?,
] {
let backend = Arc::new(
MockImpl::new()
.mark_insecure()
.next_session_tokens(next_tokens.clone())
.expected_refresh_token(prev_tokens.refresh_token.as_ref().unwrap().clone()),
);
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
let oidc = client.oidc();
// Restore the previous session so we have an existing set of refresh tokens.
oidc.restore_session(mock_session(prev_tokens.clone(), ISSUER_URL.to_owned())).await?;
oidc.restore_session(mock_session(prev_tokens.clone(), server_url.clone())).await?;
let mut session_token_stream = oidc.session_tokens_stream().expect("stream available");
@@ -380,9 +365,6 @@ async fn test_insecure_clients() -> anyhow::Result<()> {
});
assert_pending!(session_token_stream);
// There should have been exactly one refresh.
assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);
}
Ok(())