Add tests for OIDC (#2558)

* chore(oidc): put impl of OIDC server behind a trait and add tests for the OIDC flow

chore(oidc): add first tests for login/AuthorizationResponse

chore(oidc): add tests for finish_authorization/finish_login/refresh_access_token

chore(oidc): add test for logout

chore(🤷): clippy/fmt

* chore(oidc): move account management test to the new `tests` module

* chore(oidc): address review comments
This commit is contained in:
Benjamin Bouvier
2023-09-18 14:33:35 +02:00
committed by GitHub
parent f44ebf1bf9
commit 5ff83c00c5
9 changed files with 1061 additions and 269 deletions

1
Cargo.lock generated
View File

@@ -3073,6 +3073,7 @@ dependencies = [
"serde",
"serde_html_form",
"serde_json",
"serde_urlencoded",
"tempfile",
"thiserror",
"tokio",

View File

@@ -144,6 +144,7 @@ dirs = "5.0.1"
futures-executor = { workspace = true }
matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base", default_features = false, features = ["testing"] }
matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test" }
serde_urlencoded = "0.7.1"
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]

View File

@@ -14,12 +14,9 @@
use std::{collections::HashSet, num::NonZeroU32};
use chrono::Utc;
use language_tags::LanguageTag;
use mas_oidc_client::{
requests::authorization_code::{
build_authorization_url, build_par_authorization_url, AuthorizationRequestData,
},
requests::authorization_code::{build_authorization_url, AuthorizationRequestData},
types::{
requests::{Display, Prompt},
scope::Scope,
@@ -187,16 +184,15 @@ impl OidcAuthCodeUrlBuilder {
let client_credentials =
oidc.client_credentials().ok_or(OidcError::NotAuthenticated)?;
let res = build_par_authorization_url(
&oidc.http_service(),
client_credentials.clone(),
par_endpoint,
authorization_endpoint.clone(),
authorization_data.clone(),
Utc::now(),
&mut rng,
)
.await;
let res = oidc
.backend
.build_par_authorization_url(
client_credentials.clone(),
par_endpoint,
authorization_endpoint.clone(),
authorization_data.clone(),
)
.await;
match res {
Ok(res) => res,
@@ -209,8 +205,8 @@ impl OidcAuthCodeUrlBuilder {
// If the client said that PAR should be enforced, we should not try without
// it, so just return the error.
if client_metadata.require_pushed_authorization_requests.unwrap_or_default() {
return Err(error.into());
if client_metadata.require_pushed_authorization_requests.unwrap_or(false) {
return Err(error);
}
error!(

View File

@@ -0,0 +1,192 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for that specific language governing permissions and
// limitations under the License.
//! Test implementation of the OIDC backend.
use std::sync::{Arc, Mutex};
use http::StatusCode;
use mas_oidc_client::{
error::{
DiscoveryError, Error as OidcClientError, ErrorBody as OidcErrorBody,
HttpError as OidcHttpError, TokenRefreshError, TokenRequestError,
},
requests::authorization_code::{AuthorizationRequestData, AuthorizationValidationData},
types::{
client_credentials::ClientCredentials,
errors::ClientErrorCode,
iana::oauth::OAuthTokenTypeHint,
oidc::{ProviderMetadata, VerifiedProviderMetadata},
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
IdToken,
},
};
use url::Url;
use super::{OidcBackend, OidcError, RefreshedSessionTokens};
use crate::oidc::{AuthorizationCode, OidcSessionTokens};
pub(crate) const ISSUER_URL: &str = "https://oidc.example.com/issuer";
pub(crate) const AUTHORIZATION_URL: &str = "https://oidc.example.com/authorization";
pub(crate) const REVOCATION_URL: &str = "https://oidc.example.com/revocation";
pub(crate) const TOKEN_URL: &str = "https://oidc.example.com/token";
pub(crate) const JWKS_URL: &str = "https://oidc.example.com/jwks";
#[derive(Debug)]
pub(crate) struct MockImpl {
/// Must be an HTTPS URL.
issuer: String,
/// Must be an HTTPS URL.
authorization_endpoint: String,
/// Must be an HTTPS URL.
token_endpoint: String,
/// Must be an HTTPS URL.
jwks_uri: String,
/// Must be an HTTPS URL.
revocation_endpoint: String,
/// The next session tokens that will be returned by a login or refresh.
next_session_tokens: Option<OidcSessionTokens>,
/// The next refresh token that's expected for a refresh.
expected_refresh_token: Option<String>,
/// Number of refreshes that effectively happened.
pub num_refreshes: Arc<Mutex<u32>>,
/// Tokens that have been revoked with `revoke_token`.
pub revoked_tokens: Arc<Mutex<Vec<String>>>,
}
impl MockImpl {
pub fn new() -> Self {
Self {
issuer: ISSUER_URL.to_owned(),
authorization_endpoint: AUTHORIZATION_URL.to_owned(),
token_endpoint: TOKEN_URL.to_owned(),
jwks_uri: JWKS_URL.to_owned(),
revocation_endpoint: REVOCATION_URL.to_owned(),
next_session_tokens: None,
expected_refresh_token: None,
num_refreshes: Default::default(),
revoked_tokens: Default::default(),
}
}
pub fn next_session_tokens(mut self, next_session_tokens: OidcSessionTokens) -> Self {
self.next_session_tokens = Some(next_session_tokens);
self
}
pub fn expected_refresh_token(mut self, refresh_token: String) -> Self {
self.expected_refresh_token = Some(refresh_token);
self
}
}
#[async_trait::async_trait]
impl OidcBackend for MockImpl {
async fn discover(&self, issuer: &str) -> Result<VerifiedProviderMetadata, OidcError> {
Ok(ProviderMetadata {
issuer: Some(self.issuer.clone()),
authorization_endpoint: Some(Url::parse(&self.authorization_endpoint).unwrap()),
revocation_endpoint: Some(Url::parse(&self.revocation_endpoint).unwrap()),
token_endpoint: Some(Url::parse(&self.token_endpoint).unwrap()),
jwks_uri: Some(Url::parse(&self.jwks_uri).unwrap()),
response_types_supported: Some(vec![]),
subject_types_supported: Some(vec![]),
id_token_signing_alg_values_supported: Some(vec![]),
..Default::default()
}
.validate(issuer)
.map_err(DiscoveryError::from)?)
}
async fn trade_authorization_code_for_tokens(
&self,
_provider_metadata: VerifiedProviderMetadata,
_credentials: ClientCredentials,
_metadata: VerifiedClientMetadata,
_auth_code: AuthorizationCode,
_validation_data: AuthorizationValidationData,
) -> Result<OidcSessionTokens, OidcError> {
Ok(self
.next_session_tokens
.as_ref()
.expect("missing next session tokens in testing")
.clone())
}
async fn register_client(
&self,
_registration_endpoint: &Url,
_client_metadata: VerifiedClientMetadata,
_software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError> {
todo!()
}
async fn build_par_authorization_url(
&self,
_client_credentials: ClientCredentials,
_par_endpoint: &Url,
_authorization_endpoint: Url,
_authorization_data: AuthorizationRequestData,
) -> Result<(Url, AuthorizationValidationData), OidcError> {
todo!()
}
async fn revoke_token(
&self,
_client_credentials: ClientCredentials,
_revocation_endpoint: &Url,
token: String,
_token_type_hint: Option<OAuthTokenTypeHint>,
) -> Result<(), OidcError> {
self.revoked_tokens.lock().unwrap().push(token);
Ok(())
}
async fn refresh_access_token(
&self,
_provider_metadata: VerifiedProviderMetadata,
_credentials: ClientCredentials,
_metadata: &VerifiedClientMetadata,
refresh_token: String,
_latest_id_token: Option<IdToken<'static>>,
) -> Result<RefreshedSessionTokens, OidcError> {
if Some(refresh_token) != self.expected_refresh_token {
Err(OidcError::Oidc(OidcClientError::TokenRefresh(TokenRefreshError::Token(
TokenRequestError::Http(OidcHttpError {
body: Some(OidcErrorBody {
error: ClientErrorCode::InvalidGrant,
error_description: None,
}),
status: StatusCode::from_u16(400).unwrap(),
}),
))))
} else {
*self.num_refreshes.lock().unwrap() += 1;
let next_tokens = self.next_session_tokens.clone().expect("missing next tokens");
Ok(RefreshedSessionTokens {
access_token: next_tokens.access_token,
refresh_token: next_tokens.refresh_token,
})
}
}
}

View File

@@ -0,0 +1,87 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for that specific language governing permissions and
// limitations under the License.
//! Trait for defining an implementation for an OIDC backend.
//!
//! Used mostly for testing purposes.
use mas_oidc_client::{
requests::authorization_code::{AuthorizationRequestData, AuthorizationValidationData},
types::{
client_credentials::ClientCredentials,
iana::oauth::OAuthTokenTypeHint,
oidc::VerifiedProviderMetadata,
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
IdToken,
},
};
use url::Url;
use super::{AuthorizationCode, OidcError, OidcSessionTokens};
pub(crate) mod server;
#[cfg(test)]
pub(crate) mod mock;
pub(super) struct RefreshedSessionTokens {
pub access_token: String,
pub refresh_token: Option<String>,
}
#[async_trait::async_trait]
pub(super) trait OidcBackend: std::fmt::Debug + Send + Sync {
async fn discover(&self, issuer: &str) -> Result<VerifiedProviderMetadata, OidcError>;
async fn register_client(
&self,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError>;
async fn trade_authorization_code_for_tokens(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
metadata: VerifiedClientMetadata,
auth_code: AuthorizationCode,
validation_data: AuthorizationValidationData,
) -> Result<OidcSessionTokens, OidcError>;
async fn refresh_access_token(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
metadata: &VerifiedClientMetadata,
refresh_token: String,
latest_id_token: Option<IdToken<'static>>,
) -> Result<RefreshedSessionTokens, OidcError>;
async fn build_par_authorization_url(
&self,
client_credentials: ClientCredentials,
par_endpoint: &Url,
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData,
) -> Result<(Url, AuthorizationValidationData), OidcError>;
async fn revoke_token(
&self,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
) -> Result<(), OidcError>;
}

View File

@@ -0,0 +1,203 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for that specific language governing permissions and
// limitations under the License.
//! Actual implementation of the OIDC backend, using the mas_oidc_client
//! implementation.
use chrono::Utc;
use mas_oidc_client::{
http_service::HttpService,
jose::jwk::PublicJsonWebKeySet,
requests::{
authorization_code::{
access_token_with_authorization_code, build_par_authorization_url,
AuthorizationRequestData, AuthorizationValidationData,
},
discovery::discover,
jose::{fetch_jwks, JwtVerificationData},
refresh_token::refresh_access_token,
registration::register_client,
revocation::revoke_token,
},
types::{
client_credentials::ClientCredentials,
iana::oauth::OAuthTokenTypeHint,
oidc::VerifiedProviderMetadata,
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
IdToken,
},
};
use url::Url;
use super::{OidcBackend, OidcError, RefreshedSessionTokens};
use crate::{
oidc::{rng, AuthorizationCode, OidcSessionTokens},
Client,
};
#[derive(Debug)]
pub(crate) struct OidcServer {
client: Client,
}
impl OidcServer {
pub(crate) fn new(client: Client) -> Self {
Self { client }
}
fn http_service(&self) -> HttpService {
HttpService::new(self.client.inner.http_client.clone())
}
/// Fetch the OpenID Connect JSON Web Key Set at the given URI.
///
/// Returns an error if the client registration was not restored, or if an
/// error occurred when fetching the data.
async fn fetch_jwks(&self, uri: &Url) -> Result<PublicJsonWebKeySet, OidcError> {
fetch_jwks(&self.http_service(), uri).await.map_err(Into::into)
}
}
#[async_trait::async_trait]
impl OidcBackend for OidcServer {
async fn discover(&self, issuer: &str) -> Result<VerifiedProviderMetadata, OidcError> {
discover(&self.http_service(), issuer).await.map_err(Into::into)
}
async fn trade_authorization_code_for_tokens(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
metadata: VerifiedClientMetadata,
auth_code: AuthorizationCode,
validation_data: AuthorizationValidationData,
) -> Result<OidcSessionTokens, OidcError> {
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
let id_token_verification_data = JwtVerificationData {
issuer: provider_metadata.issuer(),
jwks: &jwks,
client_id: &credentials.client_id().to_owned(),
signing_algorithm: metadata.id_token_signed_response_alg(),
};
let (response, id_token) = access_token_with_authorization_code(
&self.http_service(),
credentials.clone(),
provider_metadata.token_endpoint(),
auth_code.code,
validation_data,
Some(id_token_verification_data),
Utc::now(),
&mut rng()?,
)
.await?;
Ok(OidcSessionTokens {
access_token: response.access_token,
refresh_token: response.refresh_token,
latest_id_token: id_token,
})
}
async fn refresh_access_token(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
metadata: &VerifiedClientMetadata,
refresh_token: String,
latest_id_token: Option<IdToken<'static>>,
) -> Result<RefreshedSessionTokens, OidcError> {
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
let id_token_verification_data = JwtVerificationData {
issuer: provider_metadata.issuer(),
jwks: &jwks,
client_id: &credentials.client_id().to_owned(),
signing_algorithm: &metadata.id_token_signed_response_alg().clone(),
};
refresh_access_token(
&self.http_service(),
credentials,
provider_metadata.token_endpoint(),
refresh_token,
None,
Some(id_token_verification_data),
latest_id_token.as_ref(),
Utc::now(),
&mut rng()?,
)
.await
.map(|(response, _id_token)| RefreshedSessionTokens {
access_token: response.access_token,
refresh_token: response.refresh_token,
})
.map_err(Into::into)
}
async fn register_client(
&self,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError> {
register_client(
&self.http_service(),
registration_endpoint,
client_metadata,
software_statement,
)
.await
.map_err(Into::into)
}
async fn build_par_authorization_url(
&self,
client_credentials: ClientCredentials,
par_endpoint: &Url,
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData,
) -> Result<(Url, AuthorizationValidationData), OidcError> {
Ok(build_par_authorization_url(
&self.http_service(),
client_credentials,
par_endpoint,
authorization_endpoint,
authorization_data,
Utc::now(),
&mut rng()?,
)
.await?)
}
async fn revoke_token(
&self,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
) -> Result<(), OidcError> {
Ok(revoke_token(
&self.http_service(),
client_credentials,
revocation_endpoint,
token,
token_type_hint,
Utc::now(),
&mut rng()?,
)
.await?)
}
}

View File

@@ -241,47 +241,37 @@ pub enum CrossProcessRefreshLockError {
DuplicatedLock,
}
#[cfg(test)]
#[cfg(all(test, feature = "e2e-encryption"))]
mod tests {
use mas_oidc_client::types::{
client_credentials::ClientCredentials, iana::oauth::OAuthClientAuthenticationMethod,
registration::ClientMetadata,
};
use std::sync::Arc;
use anyhow::Context as _;
use futures_util::future::join_all;
use matrix_sdk_base::SessionMeta;
use matrix_sdk_test::async_test;
use ruma::api::client::discovery::discover_homeserver::AuthenticationServerInfo;
use ruma::{
api::client::discovery::discover_homeserver::AuthenticationServerInfo, owned_device_id,
owned_user_id,
};
use wiremock::{
matchers::{method, path},
Mock, MockServer, ResponseTemplate,
};
use super::compute_session_hash;
use crate::{
oidc::{OidcSession, OidcSessionTokens, UserSession},
oidc::{
backend::mock::{MockImpl, ISSUER_URL},
tests,
tests::mock_registered_client_data,
Oidc, OidcSessionTokens,
},
test_utils::test_client_builder,
Error,
};
fn fake_session(tokens: OidcSessionTokens) -> OidcSession {
OidcSession {
credentials: ClientCredentials::None { client_id: "test_client_id".to_owned() },
metadata: ClientMetadata {
redirect_uris: Some(vec![]), // empty vector is ok lol
token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None),
..ClientMetadata::default()
}
.validate()
.expect("validate client metadata"),
user: UserSession {
meta: SessionMeta {
user_id: ruma::user_id!("@u:e.uk").to_owned(),
device_id: ruma::device_id!("XYZ").to_owned(),
},
tokens,
issuer_info: AuthenticationServerInfo::new("issuer".to_owned(), None),
},
}
}
#[cfg(feature = "e2e-encryption")]
#[async_test]
async fn test_oidc_restore_session_lock() -> Result<(), Error> {
async fn test_restore_session_lock() -> Result<(), Error> {
// Create a client that will use sqlite databases.
let tmp_dir = tempfile::tempdir()?;
@@ -305,7 +295,7 @@ mod tests {
)?;
let session_hash = compute_session_hash(&tokens);
client.oidc().restore_session(fake_session(tokens.clone())).await?;
client.oidc().restore_session(tests::mock_session(tokens.clone())).await?;
assert_eq!(client.oidc().session_tokens().unwrap(), tokens);
@@ -325,4 +315,173 @@ mod tests {
Ok(())
}
#[async_test]
async fn test_finish_login() -> anyhow::Result<()> {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/_matrix/client/r0/account/whoami"))
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
"user_id": "@joe:example.org",
"device_id": "D3V1C31D",
})))
.expect(1)
.named("`GET /whoami` good token")
.mount(&server)
.await;
let tmp_dir = tempfile::tempdir()?;
let client =
test_client_builder(Some(server.uri())).sqlite_store(tmp_dir, None).build().await?;
let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) };
// Restore registered client.
let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None);
let (client_credentials, client_metadata) = mock_registered_client_data();
oidc.restore_registered_client(issuer_info, client_metadata, client_credentials);
// Enable cross-process lock.
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
// Simulate we've done finalize_authorization / restore_session before.
let session_tokens = OidcSessionTokens {
access_token: "access".to_owned(),
refresh_token: Some("refresh".to_owned()),
latest_id_token: None,
};
oidc.set_session_tokens(session_tokens.clone());
// Now, finishing logging will get the user and device ids.
oidc.finish_login().await?;
let session_meta = client.session_meta().context("should have session meta now")?;
assert_eq!(
*session_meta,
SessionMeta {
user_id: owned_user_id!("@joe:example.org"),
device_id: owned_device_id!("D3V1C31D")
}
);
{
// The cross process lock has been correctly updated, and the next attempt to
// take it won't result in a mismatch.
let xp_manager =
oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
let guard = xp_manager.spin_lock().await?;
let actual_hash = compute_session_hash(&session_tokens);
assert_eq!(guard.db_hash, Some(actual_hash));
assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
assert!(!guard.hash_mismatch);
}
Ok(())
}
#[async_test]
async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
// This tests that refresh token works, and that it doesn't cause multiple token
// refreshes whenever one spawns two refreshes around the same time.
let tmp_dir = tempfile::tempdir()?;
let client = test_client_builder(None).sqlite_store(tmp_dir, None).build().await?;
let prev_tokens = OidcSessionTokens {
access_token: "prev-access-token".to_owned(),
refresh_token: Some("prev-refresh-token".to_owned()),
latest_id_token: None,
};
let next_tokens = OidcSessionTokens {
access_token: "next-access-token".to_owned(),
refresh_token: Some("next-refresh-token".to_owned()),
latest_id_token: None,
};
let backend = Arc::new(
MockImpl::new()
.next_session_tokens(next_tokens.clone())
.expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()),
);
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
// Enable cross-process lock.
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
// Restore the session.
oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;
// Immediately try to refresh the access token twice in parallel.
for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await {
result?;
}
// There should have been at most one refresh.
assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);
{
// The cross process lock has been correctly updated, and the next attempt to
// take it won't result in a mismatch.
let xp_manager =
oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
let guard = xp_manager.spin_lock().await?;
let actual_hash = compute_session_hash(&next_tokens);
assert_eq!(guard.db_hash, Some(actual_hash));
assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
assert!(!guard.hash_mismatch);
}
Ok(())
}
#[async_test]
async fn test_logout() -> anyhow::Result<()> {
let tmp_dir = tempfile::tempdir()?;
let client = test_client_builder(None).sqlite_store(tmp_dir, None).build().await?;
let tokens = OidcSessionTokens {
access_token: "prev-access-token".to_owned(),
refresh_token: Some("prev-refresh-token".to_owned()),
latest_id_token: None,
};
let backend = Arc::new(MockImpl::new());
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
// Enable cross-process lock.
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
// Restore the session.
oidc.restore_session(tests::mock_session(tokens.clone())).await?;
let end_session_builder = oidc.logout().await?;
// No end session builder because our test impl doesn't provide an end session
// endpoint.
assert!(end_session_builder.is_none());
// Both the access token and the refresh tokens have been invalidated.
{
let revoked = backend.revoked_tokens.lock().unwrap();
assert_eq!(revoked.len(), 2);
assert_eq!(
*revoked,
vec![tokens.access_token.clone(), tokens.refresh_token.clone().unwrap(),]
);
}
{
// The cross process lock has been correctly updated, and all the hashes are
// empty after a logout.
let xp_manager =
oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
let guard = xp_manager.spin_lock().await?;
assert!(guard.db_hash.is_none());
assert!(guard.hash_guard.is_none());
assert!(!guard.hash_mismatch);
}
Ok(())
}
}

View File

@@ -173,28 +173,17 @@ use std::{
};
use as_variant::as_variant;
use chrono::Utc;
use eyeball::SharedObservable;
use futures_core::Stream;
pub use mas_oidc_client::{error, types};
use mas_oidc_client::{
http_service::HttpService,
jose::jwk::PublicJsonWebKeySet,
requests::{
authorization_code::{access_token_with_authorization_code, AuthorizationValidationData},
discovery::discover,
jose::{fetch_jwks, JwtVerificationData},
refresh_token::refresh_access_token,
registration::register_client,
revocation::revoke_token,
},
requests::authorization_code::AuthorizationValidationData,
types::{
client_credentials::ClientCredentials,
errors::ClientError,
iana::oauth::OAuthTokenTypeHint,
oidc::VerifiedProviderMetadata,
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
requests::AccessTokenResponse,
scope::{MatrixApiScopeToken, Scope, ScopeToken},
IdToken,
},
@@ -209,17 +198,23 @@ use tracing::{error, trace, warn};
use url::Url;
mod auth_code_builder;
mod backend;
mod cross_process;
mod data_serde;
mod end_session_builder;
#[cfg(test)]
mod tests;
use self::cross_process::{
CrossProcessRefreshLockError, CrossProcessRefreshLockGuard, CrossProcessRefreshManager,
};
pub use self::{
auth_code_builder::{OidcAuthCodeUrlBuilder, OidcAuthorizationData},
end_session_builder::{OidcEndSessionData, OidcEndSessionUrlBuilder},
};
use self::{
backend::{server::OidcServer, OidcBackend},
cross_process::{
CrossProcessRefreshLockError, CrossProcessRefreshLockGuard, CrossProcessRefreshManager,
},
};
use crate::{authentication::AuthData, client::SessionChange, Client, RefreshTokenError, Result};
pub(crate) struct OidcCtx {
@@ -268,11 +263,14 @@ impl fmt::Debug for OidcAuthData {
pub struct Oidc {
/// The underlying Matrix API client.
client: Client,
/// The implementation of the OIDC backend.
backend: Arc<dyn OidcBackend>,
}
impl Oidc {
pub(crate) fn new(client: Client) -> Self {
Self { client }
Self { client: client.clone(), backend: Arc::new(OidcServer::new(client)) }
}
fn ctx(&self) -> &OidcCtx {
@@ -338,11 +336,6 @@ impl Oidc {
Ok(())
}
/// Get the `HttpService` to make requests with.
fn http_service(&self) -> HttpService {
HttpService::new(self.client.inner.http_client.clone())
}
/// The OpenID Connect authentication data.
///
/// Returns `None` if the client registration was not restored with
@@ -427,7 +420,7 @@ impl Oidc {
&self,
issuer: &str,
) -> Result<VerifiedProviderMetadata, OidcError> {
discover(&self.http_service(), issuer).await.map_err(Into::into)
self.backend.discover(issuer).await
}
/// Fetch the OpenID Connect metadata of the issuer.
@@ -440,14 +433,6 @@ impl Oidc {
self.given_provider_metadata(issuer).await
}
/// Fetch the OpenID Connect JSON Web Key Set at the given URI.
///
/// Returns an error if the client registration was not restored, or if an
/// error occurred when fetching the data.
async fn fetch_jwks(&self, jwks_uri: &Url) -> Result<PublicJsonWebKeySet, OidcError> {
Ok(fetch_jwks(&self.http_service(), jwks_uri).await?)
}
/// The OpenID Connect metadata of this client used during registration.
///
/// Returns `None` if the client registration was not restored with
@@ -654,20 +639,16 @@ impl Oidc {
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError> {
let provider_metadata = self.given_provider_metadata(issuer).await?;
let provider_metadata = self.backend.discover(issuer).await?;
let registration_endpoint = provider_metadata
.registration_endpoint
.as_ref()
.ok_or(OidcError::NoRegistrationSupport)?;
register_client(
&self.http_service(),
registration_endpoint,
client_metadata,
software_statement,
)
.await
.map_err(Into::into)
self.backend
.register_client(registration_endpoint, client_metadata, software_statement)
.await
}
/// Set the data of a client that is registered with an OpenID Connect
@@ -1007,44 +988,32 @@ impl Oidc {
/// Returns an error if a request fails.
pub async fn finish_authorization(
&self,
code: AuthorizationCode,
) -> Result<AccessTokenResponse, OidcError> {
let provider_metadata = self.provider_metadata().await?;
auth_code: AuthorizationCode,
) -> Result<(), OidcError> {
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
let validation_data = data
.authorization_data
.lock()
.await
.remove(&code.state)
.remove(&auth_code.state)
.ok_or(OidcError::InvalidState)?;
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
let id_token_verification_data = JwtVerificationData {
issuer: provider_metadata.issuer(),
jwks: &jwks,
client_id: &data.credentials.client_id().to_owned(),
signing_algorithm: data.metadata.id_token_signed_response_alg(),
};
let provider_metadata = self.provider_metadata().await?;
let (response, id_token) = access_token_with_authorization_code(
&self.http_service(),
data.credentials.clone(),
provider_metadata.token_endpoint(),
code.code,
validation_data,
Some(id_token_verification_data),
Utc::now(),
&mut rng()?,
)
.await?;
let session_tokens = self
.backend
.trade_authorization_code_for_tokens(
provider_metadata,
data.credentials.clone(),
data.metadata.clone(),
auth_code,
validation_data,
)
.await?;
self.set_session_tokens(OidcSessionTokens {
access_token: response.access_token.clone(),
refresh_token: response.refresh_token.clone(),
latest_id_token: id_token,
});
self.set_session_tokens(session_tokens);
Ok(response)
Ok(())
}
/// Abort the authorization process.
@@ -1074,52 +1043,44 @@ impl Oidc {
latest_id_token: Option<IdToken<'static>>,
lock: Option<CrossProcessRefreshLockGuard>,
) -> Result<(), OidcError> {
let provider_metadata = self.provider_metadata().await?;
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
trace!("Token refresh: attempting to refresh with refresh_token {}", hash(&refresh_token));
// Do not interrupt refresh access token requests and processing, by detaching
// the request sending and response processing.
let provider_metadata = self.provider_metadata().await?;
let this = self.clone();
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
let credentials = data.credentials.clone();
let signing_algorithm = data.metadata.id_token_signed_response_alg().clone();
let metadata = data.metadata.clone();
spawn(async move {
let id_token_verification_data = JwtVerificationData {
issuer: provider_metadata.issuer(),
jwks: &jwks,
client_id: &credentials.client_id().to_owned(),
signing_algorithm: &signing_algorithm,
};
tracing::trace!(
"Token refresh: attempting to refresh with refresh_token {}",
hash(&refresh_token)
);
match refresh_access_token(
&this.http_service(),
credentials,
provider_metadata.token_endpoint(),
refresh_token.clone(),
None,
Some(id_token_verification_data),
latest_id_token.as_ref(),
Utc::now(),
&mut rng()?,
)
.await
.map_err(OidcError::from)
match this
.backend
.refresh_access_token(
provider_metadata,
credentials,
&metadata,
refresh_token.clone(),
latest_id_token.clone(),
)
.await
.map_err(OidcError::from)
{
Ok((response, _id_token)) => {
Ok(new_tokens) => {
trace!(
"Token refresh: new refresh_token: {:?} / access_token: {}",
response.refresh_token.as_ref().map(hash),
hash(&response.access_token)
new_tokens.refresh_token.as_ref().map(hash),
hash(&new_tokens.access_token)
);
let tokens = OidcSessionTokens {
access_token: response.access_token,
refresh_token: response.refresh_token.clone().or(Some(refresh_token)),
access_token: new_tokens.access_token,
refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)),
latest_id_token,
};
@@ -1260,32 +1221,27 @@ impl Oidc {
provider_metadata.revocation_endpoint.as_ref().ok_or(OidcError::NoRevocationSupport)?;
let tokens = self.session_tokens().ok_or(OidcError::NotAuthenticated)?;
let mut rng = rng()?;
// Revoke the access token.
revoke_token(
&self.http_service(),
client_credentials.clone(),
revocation_endpoint,
tokens.access_token,
Some(OAuthTokenTypeHint::AccessToken),
Utc::now(),
&mut rng,
)
.await?;
self.backend
.revoke_token(
client_credentials.clone(),
revocation_endpoint,
tokens.access_token,
Some(OAuthTokenTypeHint::AccessToken),
)
.await?;
// Revoke the refresh token, if any.
if let Some(refresh_token) = tokens.refresh_token {
revoke_token(
&self.http_service(),
client_credentials.clone(),
revocation_endpoint,
refresh_token,
Some(OAuthTokenTypeHint::RefreshToken),
Utc::now(),
&mut rng,
)
.await?;
self.backend
.revoke_token(
client_credentials.clone(),
revocation_endpoint,
refresh_token,
Some(OAuthTokenTypeHint::RefreshToken),
)
.await?;
}
let end_session_builder =
@@ -1522,105 +1478,3 @@ fn hash<T: Hash>(x: &T) -> u64 {
x.hash(&mut hasher);
hasher.finish()
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use mas_oidc_client::types::{
client_credentials::ClientCredentials, registration::ClientMetadata,
};
use matrix_sdk_base::SessionMeta;
use matrix_sdk_test::async_test;
use ruma::{
api::client::discovery::discover_homeserver::AuthenticationServerInfo, OwnedUserId,
};
use url::Url;
use crate::{
oidc::{OidcAccountManagementAction, OidcSession, OidcSessionTokens, UserSession},
ClientBuilder,
};
#[async_test]
async fn test_account_management_url() {
let builder =
ClientBuilder::new().homeserver_url(Url::parse("https://example.com").unwrap());
let client = builder.build().await.unwrap();
client
.restore_session(OidcSession {
credentials: ClientCredentials::None { client_id: "client_id".to_owned() },
metadata: ClientMetadata {
redirect_uris: Some(vec![Url::parse("https://example.com/login").unwrap()]),
..Default::default()
}
.validate()
.unwrap(),
user: UserSession {
meta: SessionMeta {
user_id: OwnedUserId::from_str("@user:example.com").unwrap(),
device_id: "device_id".into(),
},
tokens: OidcSessionTokens {
access_token: "access_token".to_owned(),
refresh_token: None,
latest_id_token: None,
},
issuer_info: AuthenticationServerInfo::new(
"https://example.com".to_owned(),
Some("https://example.com/account".to_owned()),
),
},
})
.await
.unwrap();
assert_eq!(
client.oidc().account_management_url(None).unwrap(),
Some(Url::parse("https://example.com/account").unwrap())
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::Profile))
.unwrap(),
Some(Url::parse("https://example.com/account?action=profile").unwrap())
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::SessionsList))
.unwrap(),
Some(Url::parse("https://example.com/account?action=sessions_list").unwrap())
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::SessionView {
device_id: "my_phone".into()
}))
.unwrap(),
Some(
Url::parse("https://example.com/account?action=session_view&device_id=my_phone")
.unwrap()
)
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::SessionEnd {
device_id: "my_old_phone".into()
}))
.unwrap(),
Some(
Url::parse("https://example.com/account?action=session_end&device_id=my_old_phone")
.unwrap()
)
);
}
}

View File

@@ -0,0 +1,299 @@
use std::sync::Arc;
use anyhow::Context as _;
use assert_matches::assert_matches;
use mas_oidc_client::{
requests::authorization_code::AuthorizationValidationData,
types::{
client_credentials::ClientCredentials,
errors::ClientErrorCode,
iana::oauth::OAuthClientAuthenticationMethod,
registration::{ClientMetadata, VerifiedClientMetadata},
},
};
use matrix_sdk_base::SessionMeta;
#[cfg(test)]
use matrix_sdk_test::async_test;
use ruma::{api::client::discovery::discover_homeserver::AuthenticationServerInfo, owned_user_id};
use url::Url;
use super::{
backend::mock::{MockImpl, AUTHORIZATION_URL, ISSUER_URL},
AuthorizationCode, AuthorizationError, AuthorizationResponse, Oidc,
OidcAccountManagementAction, OidcError, OidcSession, OidcSessionTokens,
RedirectUriQueryParseError, UserSession,
};
use crate::test_utils::test_client_builder;
const CLIENT_ID: &str = "test_client_id";
pub fn mock_registered_client_data() -> (ClientCredentials, VerifiedClientMetadata) {
(
ClientCredentials::None { client_id: CLIENT_ID.to_owned() },
ClientMetadata {
redirect_uris: Some(vec![]), // empty vector is ok lol
token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None),
..ClientMetadata::default()
}
.validate()
.expect("validate client metadata"),
)
}
pub fn mock_session(tokens: OidcSessionTokens) -> OidcSession {
let (credentials, metadata) = mock_registered_client_data();
OidcSession {
credentials,
metadata,
user: UserSession {
meta: SessionMeta {
user_id: ruma::user_id!("@u:e.uk").to_owned(),
device_id: ruma::device_id!("XYZ").to_owned(),
},
tokens,
issuer_info: AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None),
},
}
}
#[async_test]
async fn test_account_management_url() {
let builder = test_client_builder(Some("https://example.com".to_owned()));
let client = builder.build().await.unwrap();
client
.restore_session(OidcSession {
credentials: ClientCredentials::None { client_id: "client_id".to_owned() },
metadata: ClientMetadata {
redirect_uris: Some(vec![Url::parse("https://example.com/login").unwrap()]),
..Default::default()
}
.validate()
.unwrap(),
user: UserSession {
meta: SessionMeta {
user_id: owned_user_id!("@user:example.com"),
device_id: "device_id".into(),
},
tokens: OidcSessionTokens {
access_token: "access_token".to_owned(),
refresh_token: None,
latest_id_token: None,
},
issuer_info: AuthenticationServerInfo::new(
"https://example.com".to_owned(),
Some("https://example.com/account".to_owned()),
),
},
})
.await
.unwrap();
assert_eq!(
client.oidc().account_management_url(None).unwrap(),
Some(Url::parse("https://example.com/account").unwrap())
);
assert_eq!(
client.oidc().account_management_url(Some(OidcAccountManagementAction::Profile)).unwrap(),
Some(Url::parse("https://example.com/account?action=profile").unwrap())
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::SessionsList))
.unwrap(),
Some(Url::parse("https://example.com/account?action=sessions_list").unwrap())
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::SessionView {
device_id: "my_phone".into()
}))
.unwrap(),
Some(
Url::parse("https://example.com/account?action=session_view&device_id=my_phone")
.unwrap()
)
);
assert_eq!(
client
.oidc()
.account_management_url(Some(OidcAccountManagementAction::SessionEnd {
device_id: "my_old_phone".into()
}))
.unwrap(),
Some(
Url::parse("https://example.com/account?action=session_end&device_id=my_old_phone")
.unwrap()
)
);
}
#[async_test]
async fn test_login() -> anyhow::Result<()> {
let client = test_client_builder(None).build().await?;
let device_id = "D3V1C31D".to_owned(); // yo this is 1999 speaking
let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) };
let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None);
let (client_credentials, client_metadata) = mock_registered_client_data();
oidc.restore_registered_client(issuer_info, client_metadata, client_credentials);
let redirect_uri_str = "http://matrix.example.com/oidc/callback";
let redirect_uri = Url::parse(redirect_uri_str)?;
let mut authorization_data = oidc.login(redirect_uri, Some(device_id.clone()))?.build().await?;
tracing::debug!("authorization data URL = {}", authorization_data.url);
let mut num_expected = 6;
let mut nonce = None;
for (key, val) in authorization_data.url.query_pairs() {
match &*key {
"response_type" => {
assert_eq!(val, "code");
num_expected -= 1;
}
"client_id" => {
assert_eq!(val, CLIENT_ID);
num_expected -= 1;
}
"redirect_uri" => {
assert_eq!(val, redirect_uri_str);
num_expected -= 1;
}
"scope" => {
assert_eq!(val, format!("openid urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id}"));
num_expected -= 1;
}
"state" => {
num_expected -= 1;
assert_eq!(val, authorization_data.state);
}
"nonce" => {
num_expected -= 1;
nonce = Some(val);
}
_ => panic!("unexpected query parameter: {key}={val}"),
}
}
assert_eq!(num_expected, 0);
let data = oidc.data().unwrap();
let authorization_data_guard = data.authorization_data.lock().await;
let state = authorization_data_guard.get(&authorization_data.state).context("missing state")?;
let nonce = nonce.context("missing nonce")?;
assert_eq!(nonce, state.nonce);
authorization_data.url.set_query(None);
assert_eq!(authorization_data.url, Url::parse(AUTHORIZATION_URL).unwrap(),);
Ok(())
}
#[test]
fn test_authorization_response() -> anyhow::Result<()> {
let uri = Url::parse("https://example.com")?;
assert_matches!(
AuthorizationResponse::parse_uri(&uri),
Err(RedirectUriQueryParseError::MissingQuery)
);
let uri = Url::parse("https://example.com?code=123&state=456")?;
assert_matches!(
AuthorizationResponse::parse_uri(&uri),
Ok(AuthorizationResponse::Success(AuthorizationCode { code, state })) => {
assert_eq!(code, "123");
assert_eq!(state, "456");
}
);
let uri = Url::parse("https://example.com?error=invalid_grant&state=456")?;
assert_matches!(
AuthorizationResponse::parse_uri(&uri),
Ok(AuthorizationResponse::Error(AuthorizationError { error, state })) => {
assert_eq!(error.error, ClientErrorCode::InvalidGrant);
assert_eq!(error.error_description, None);
assert_eq!(state, "456");
}
);
Ok(())
}
#[async_test]
async fn test_finish_authorization() -> anyhow::Result<()> {
let client = test_client_builder(None).build().await?;
let session_tokens = OidcSessionTokens {
access_token: "4cc3ss".to_owned(),
refresh_token: Some("r3fr3$h".to_owned()),
latest_id_token: None,
};
let oidc = Oidc {
client: client.clone(),
backend: Arc::new(MockImpl::new().next_session_tokens(session_tokens.clone())),
};
let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None);
let (client_credentials, client_metadata) = mock_registered_client_data();
oidc.restore_registered_client(issuer_info, client_metadata, client_credentials);
// If the state is missing, then any attempt to finish authorizing will fail.
let res = oidc
.finish_authorization(AuthorizationCode { code: "42".to_owned(), state: "none".to_owned() })
.await;
assert_matches!(res, Err(OidcError::InvalidState));
assert!(oidc.session_tokens().is_none());
// Assuming a non-empty state "123"...
let state = "state".to_owned();
let redirect_uri = "http://matrix.example.com/oidc/callback";
let auth_validation_data = AuthorizationValidationData {
state: state.clone(),
nonce: "nonce".to_owned(),
redirect_uri: Url::parse(redirect_uri)?,
code_challenge_verifier: None,
};
{
let data = oidc.data().context("missing data")?;
let prev = data.authorization_data.lock().await.insert(state.clone(), {
AuthorizationValidationData { ..auth_validation_data.clone() }
});
assert!(prev.is_none());
}
// Finishing the authorization for another state won't work.
let res = oidc
.finish_authorization(AuthorizationCode {
code: "1337".to_owned(),
state: "none".to_owned(),
})
.await;
assert_matches!(res, Err(OidcError::InvalidState));
assert!(oidc.session_tokens().is_none());
assert!(oidc.data().unwrap().authorization_data.lock().await.get(&state).is_some());
// Finishing the authorization for the expected state will work.
oidc.finish_authorization(AuthorizationCode { code: "1337".to_owned(), state: state.clone() })
.await?;
assert_eq!(oidc.session_tokens(), Some(session_tokens));
assert!(oidc.data().unwrap().authorization_data.lock().await.get(&state).is_none());
Ok(())
}