diff --git a/Cargo.lock b/Cargo.lock index 24d61fefa..4564b1a04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3073,6 +3073,7 @@ dependencies = [ "serde", "serde_html_form", "serde_json", + "serde_urlencoded", "tempfile", "thiserror", "tokio", diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index f23699de1..d14b6fbff 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -144,6 +144,7 @@ dirs = "5.0.1" futures-executor = { workspace = true } matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base", default_features = false, features = ["testing"] } matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test" } +serde_urlencoded = "0.7.1" tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] diff --git a/crates/matrix-sdk/src/oidc/auth_code_builder.rs b/crates/matrix-sdk/src/oidc/auth_code_builder.rs index 51bea3ea4..29c9a4fe4 100644 --- a/crates/matrix-sdk/src/oidc/auth_code_builder.rs +++ b/crates/matrix-sdk/src/oidc/auth_code_builder.rs @@ -14,12 +14,9 @@ use std::{collections::HashSet, num::NonZeroU32}; -use chrono::Utc; use language_tags::LanguageTag; use mas_oidc_client::{ - requests::authorization_code::{ - build_authorization_url, build_par_authorization_url, AuthorizationRequestData, - }, + requests::authorization_code::{build_authorization_url, AuthorizationRequestData}, types::{ requests::{Display, Prompt}, scope::Scope, @@ -187,16 +184,15 @@ impl OidcAuthCodeUrlBuilder { let client_credentials = oidc.client_credentials().ok_or(OidcError::NotAuthenticated)?; - let res = build_par_authorization_url( - &oidc.http_service(), - client_credentials.clone(), - par_endpoint, - authorization_endpoint.clone(), - authorization_data.clone(), - Utc::now(), - &mut rng, - ) - .await; + let res = oidc + .backend + .build_par_authorization_url( + client_credentials.clone(), + par_endpoint, + authorization_endpoint.clone(), + authorization_data.clone(), + ) + .await; match res { Ok(res) => res, @@ -209,8 +205,8 @@ impl OidcAuthCodeUrlBuilder { // If the client said that PAR should be enforced, we should not try without // it, so just return the error. - if client_metadata.require_pushed_authorization_requests.unwrap_or_default() { - return Err(error.into()); + if client_metadata.require_pushed_authorization_requests.unwrap_or(false) { + return Err(error); } error!( diff --git a/crates/matrix-sdk/src/oidc/backend/mock.rs b/crates/matrix-sdk/src/oidc/backend/mock.rs new file mode 100644 index 000000000..8aef89b03 --- /dev/null +++ b/crates/matrix-sdk/src/oidc/backend/mock.rs @@ -0,0 +1,192 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for that specific language governing permissions and +// limitations under the License. + +//! Test implementation of the OIDC backend. + +use std::sync::{Arc, Mutex}; + +use http::StatusCode; +use mas_oidc_client::{ + error::{ + DiscoveryError, Error as OidcClientError, ErrorBody as OidcErrorBody, + HttpError as OidcHttpError, TokenRefreshError, TokenRequestError, + }, + requests::authorization_code::{AuthorizationRequestData, AuthorizationValidationData}, + types::{ + client_credentials::ClientCredentials, + errors::ClientErrorCode, + iana::oauth::OAuthTokenTypeHint, + oidc::{ProviderMetadata, VerifiedProviderMetadata}, + registration::{ClientRegistrationResponse, VerifiedClientMetadata}, + IdToken, + }, +}; +use url::Url; + +use super::{OidcBackend, OidcError, RefreshedSessionTokens}; +use crate::oidc::{AuthorizationCode, OidcSessionTokens}; + +pub(crate) const ISSUER_URL: &str = "https://oidc.example.com/issuer"; +pub(crate) const AUTHORIZATION_URL: &str = "https://oidc.example.com/authorization"; +pub(crate) const REVOCATION_URL: &str = "https://oidc.example.com/revocation"; +pub(crate) const TOKEN_URL: &str = "https://oidc.example.com/token"; +pub(crate) const JWKS_URL: &str = "https://oidc.example.com/jwks"; + +#[derive(Debug)] +pub(crate) struct MockImpl { + /// Must be an HTTPS URL. + issuer: String, + + /// Must be an HTTPS URL. + authorization_endpoint: String, + + /// Must be an HTTPS URL. + token_endpoint: String, + + /// Must be an HTTPS URL. + jwks_uri: String, + + /// Must be an HTTPS URL. + revocation_endpoint: String, + + /// The next session tokens that will be returned by a login or refresh. + next_session_tokens: Option, + + /// The next refresh token that's expected for a refresh. + expected_refresh_token: Option, + + /// Number of refreshes that effectively happened. + pub num_refreshes: Arc>, + + /// Tokens that have been revoked with `revoke_token`. + pub revoked_tokens: Arc>>, +} + +impl MockImpl { + pub fn new() -> Self { + Self { + issuer: ISSUER_URL.to_owned(), + authorization_endpoint: AUTHORIZATION_URL.to_owned(), + token_endpoint: TOKEN_URL.to_owned(), + jwks_uri: JWKS_URL.to_owned(), + revocation_endpoint: REVOCATION_URL.to_owned(), + next_session_tokens: None, + expected_refresh_token: None, + num_refreshes: Default::default(), + revoked_tokens: Default::default(), + } + } + + pub fn next_session_tokens(mut self, next_session_tokens: OidcSessionTokens) -> Self { + self.next_session_tokens = Some(next_session_tokens); + self + } + + pub fn expected_refresh_token(mut self, refresh_token: String) -> Self { + self.expected_refresh_token = Some(refresh_token); + self + } +} + +#[async_trait::async_trait] +impl OidcBackend for MockImpl { + async fn discover(&self, issuer: &str) -> Result { + Ok(ProviderMetadata { + issuer: Some(self.issuer.clone()), + authorization_endpoint: Some(Url::parse(&self.authorization_endpoint).unwrap()), + revocation_endpoint: Some(Url::parse(&self.revocation_endpoint).unwrap()), + token_endpoint: Some(Url::parse(&self.token_endpoint).unwrap()), + jwks_uri: Some(Url::parse(&self.jwks_uri).unwrap()), + response_types_supported: Some(vec![]), + subject_types_supported: Some(vec![]), + id_token_signing_alg_values_supported: Some(vec![]), + ..Default::default() + } + .validate(issuer) + .map_err(DiscoveryError::from)?) + } + + async fn trade_authorization_code_for_tokens( + &self, + _provider_metadata: VerifiedProviderMetadata, + _credentials: ClientCredentials, + _metadata: VerifiedClientMetadata, + _auth_code: AuthorizationCode, + _validation_data: AuthorizationValidationData, + ) -> Result { + Ok(self + .next_session_tokens + .as_ref() + .expect("missing next session tokens in testing") + .clone()) + } + + async fn register_client( + &self, + _registration_endpoint: &Url, + _client_metadata: VerifiedClientMetadata, + _software_statement: Option, + ) -> Result { + todo!() + } + + async fn build_par_authorization_url( + &self, + _client_credentials: ClientCredentials, + _par_endpoint: &Url, + _authorization_endpoint: Url, + _authorization_data: AuthorizationRequestData, + ) -> Result<(Url, AuthorizationValidationData), OidcError> { + todo!() + } + + async fn revoke_token( + &self, + _client_credentials: ClientCredentials, + _revocation_endpoint: &Url, + token: String, + _token_type_hint: Option, + ) -> Result<(), OidcError> { + self.revoked_tokens.lock().unwrap().push(token); + Ok(()) + } + + async fn refresh_access_token( + &self, + _provider_metadata: VerifiedProviderMetadata, + _credentials: ClientCredentials, + _metadata: &VerifiedClientMetadata, + refresh_token: String, + _latest_id_token: Option>, + ) -> Result { + if Some(refresh_token) != self.expected_refresh_token { + Err(OidcError::Oidc(OidcClientError::TokenRefresh(TokenRefreshError::Token( + TokenRequestError::Http(OidcHttpError { + body: Some(OidcErrorBody { + error: ClientErrorCode::InvalidGrant, + error_description: None, + }), + status: StatusCode::from_u16(400).unwrap(), + }), + )))) + } else { + *self.num_refreshes.lock().unwrap() += 1; + let next_tokens = self.next_session_tokens.clone().expect("missing next tokens"); + Ok(RefreshedSessionTokens { + access_token: next_tokens.access_token, + refresh_token: next_tokens.refresh_token, + }) + } + } +} diff --git a/crates/matrix-sdk/src/oidc/backend/mod.rs b/crates/matrix-sdk/src/oidc/backend/mod.rs new file mode 100644 index 000000000..f76ae0e9c --- /dev/null +++ b/crates/matrix-sdk/src/oidc/backend/mod.rs @@ -0,0 +1,87 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for that specific language governing permissions and +// limitations under the License. + +//! Trait for defining an implementation for an OIDC backend. +//! +//! Used mostly for testing purposes. + +use mas_oidc_client::{ + requests::authorization_code::{AuthorizationRequestData, AuthorizationValidationData}, + types::{ + client_credentials::ClientCredentials, + iana::oauth::OAuthTokenTypeHint, + oidc::VerifiedProviderMetadata, + registration::{ClientRegistrationResponse, VerifiedClientMetadata}, + IdToken, + }, +}; +use url::Url; + +use super::{AuthorizationCode, OidcError, OidcSessionTokens}; + +pub(crate) mod server; + +#[cfg(test)] +pub(crate) mod mock; + +pub(super) struct RefreshedSessionTokens { + pub access_token: String, + pub refresh_token: Option, +} + +#[async_trait::async_trait] +pub(super) trait OidcBackend: std::fmt::Debug + Send + Sync { + async fn discover(&self, issuer: &str) -> Result; + + async fn register_client( + &self, + registration_endpoint: &Url, + client_metadata: VerifiedClientMetadata, + software_statement: Option, + ) -> Result; + + async fn trade_authorization_code_for_tokens( + &self, + provider_metadata: VerifiedProviderMetadata, + credentials: ClientCredentials, + metadata: VerifiedClientMetadata, + auth_code: AuthorizationCode, + validation_data: AuthorizationValidationData, + ) -> Result; + + async fn refresh_access_token( + &self, + provider_metadata: VerifiedProviderMetadata, + credentials: ClientCredentials, + metadata: &VerifiedClientMetadata, + refresh_token: String, + latest_id_token: Option>, + ) -> Result; + + async fn build_par_authorization_url( + &self, + client_credentials: ClientCredentials, + par_endpoint: &Url, + authorization_endpoint: Url, + authorization_data: AuthorizationRequestData, + ) -> Result<(Url, AuthorizationValidationData), OidcError>; + + async fn revoke_token( + &self, + client_credentials: ClientCredentials, + revocation_endpoint: &Url, + token: String, + token_type_hint: Option, + ) -> Result<(), OidcError>; +} diff --git a/crates/matrix-sdk/src/oidc/backend/server.rs b/crates/matrix-sdk/src/oidc/backend/server.rs new file mode 100644 index 000000000..0163e6137 --- /dev/null +++ b/crates/matrix-sdk/src/oidc/backend/server.rs @@ -0,0 +1,203 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for that specific language governing permissions and +// limitations under the License. + +//! Actual implementation of the OIDC backend, using the mas_oidc_client +//! implementation. + +use chrono::Utc; +use mas_oidc_client::{ + http_service::HttpService, + jose::jwk::PublicJsonWebKeySet, + requests::{ + authorization_code::{ + access_token_with_authorization_code, build_par_authorization_url, + AuthorizationRequestData, AuthorizationValidationData, + }, + discovery::discover, + jose::{fetch_jwks, JwtVerificationData}, + refresh_token::refresh_access_token, + registration::register_client, + revocation::revoke_token, + }, + types::{ + client_credentials::ClientCredentials, + iana::oauth::OAuthTokenTypeHint, + oidc::VerifiedProviderMetadata, + registration::{ClientRegistrationResponse, VerifiedClientMetadata}, + IdToken, + }, +}; +use url::Url; + +use super::{OidcBackend, OidcError, RefreshedSessionTokens}; +use crate::{ + oidc::{rng, AuthorizationCode, OidcSessionTokens}, + Client, +}; + +#[derive(Debug)] +pub(crate) struct OidcServer { + client: Client, +} + +impl OidcServer { + pub(crate) fn new(client: Client) -> Self { + Self { client } + } + + fn http_service(&self) -> HttpService { + HttpService::new(self.client.inner.http_client.clone()) + } + + /// Fetch the OpenID Connect JSON Web Key Set at the given URI. + /// + /// Returns an error if the client registration was not restored, or if an + /// error occurred when fetching the data. + async fn fetch_jwks(&self, uri: &Url) -> Result { + fetch_jwks(&self.http_service(), uri).await.map_err(Into::into) + } +} + +#[async_trait::async_trait] +impl OidcBackend for OidcServer { + async fn discover(&self, issuer: &str) -> Result { + discover(&self.http_service(), issuer).await.map_err(Into::into) + } + + async fn trade_authorization_code_for_tokens( + &self, + provider_metadata: VerifiedProviderMetadata, + credentials: ClientCredentials, + metadata: VerifiedClientMetadata, + auth_code: AuthorizationCode, + validation_data: AuthorizationValidationData, + ) -> Result { + let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?; + + let id_token_verification_data = JwtVerificationData { + issuer: provider_metadata.issuer(), + jwks: &jwks, + client_id: &credentials.client_id().to_owned(), + signing_algorithm: metadata.id_token_signed_response_alg(), + }; + + let (response, id_token) = access_token_with_authorization_code( + &self.http_service(), + credentials.clone(), + provider_metadata.token_endpoint(), + auth_code.code, + validation_data, + Some(id_token_verification_data), + Utc::now(), + &mut rng()?, + ) + .await?; + + Ok(OidcSessionTokens { + access_token: response.access_token, + refresh_token: response.refresh_token, + latest_id_token: id_token, + }) + } + + async fn refresh_access_token( + &self, + provider_metadata: VerifiedProviderMetadata, + credentials: ClientCredentials, + metadata: &VerifiedClientMetadata, + refresh_token: String, + latest_id_token: Option>, + ) -> Result { + let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?; + + let id_token_verification_data = JwtVerificationData { + issuer: provider_metadata.issuer(), + jwks: &jwks, + client_id: &credentials.client_id().to_owned(), + signing_algorithm: &metadata.id_token_signed_response_alg().clone(), + }; + + refresh_access_token( + &self.http_service(), + credentials, + provider_metadata.token_endpoint(), + refresh_token, + None, + Some(id_token_verification_data), + latest_id_token.as_ref(), + Utc::now(), + &mut rng()?, + ) + .await + .map(|(response, _id_token)| RefreshedSessionTokens { + access_token: response.access_token, + refresh_token: response.refresh_token, + }) + .map_err(Into::into) + } + + async fn register_client( + &self, + registration_endpoint: &Url, + client_metadata: VerifiedClientMetadata, + software_statement: Option, + ) -> Result { + register_client( + &self.http_service(), + registration_endpoint, + client_metadata, + software_statement, + ) + .await + .map_err(Into::into) + } + + async fn build_par_authorization_url( + &self, + client_credentials: ClientCredentials, + par_endpoint: &Url, + authorization_endpoint: Url, + authorization_data: AuthorizationRequestData, + ) -> Result<(Url, AuthorizationValidationData), OidcError> { + Ok(build_par_authorization_url( + &self.http_service(), + client_credentials, + par_endpoint, + authorization_endpoint, + authorization_data, + Utc::now(), + &mut rng()?, + ) + .await?) + } + + async fn revoke_token( + &self, + client_credentials: ClientCredentials, + revocation_endpoint: &Url, + token: String, + token_type_hint: Option, + ) -> Result<(), OidcError> { + Ok(revoke_token( + &self.http_service(), + client_credentials, + revocation_endpoint, + token, + token_type_hint, + Utc::now(), + &mut rng()?, + ) + .await?) + } +} diff --git a/crates/matrix-sdk/src/oidc/cross_process.rs b/crates/matrix-sdk/src/oidc/cross_process.rs index 06e46698d..7a78c0d3d 100644 --- a/crates/matrix-sdk/src/oidc/cross_process.rs +++ b/crates/matrix-sdk/src/oidc/cross_process.rs @@ -241,47 +241,37 @@ pub enum CrossProcessRefreshLockError { DuplicatedLock, } -#[cfg(test)] +#[cfg(all(test, feature = "e2e-encryption"))] mod tests { - use mas_oidc_client::types::{ - client_credentials::ClientCredentials, iana::oauth::OAuthClientAuthenticationMethod, - registration::ClientMetadata, - }; + use std::sync::Arc; + + use anyhow::Context as _; + use futures_util::future::join_all; use matrix_sdk_base::SessionMeta; use matrix_sdk_test::async_test; - use ruma::api::client::discovery::discover_homeserver::AuthenticationServerInfo; + use ruma::{ + api::client::discovery::discover_homeserver::AuthenticationServerInfo, owned_device_id, + owned_user_id, + }; + use wiremock::{ + matchers::{method, path}, + Mock, MockServer, ResponseTemplate, + }; use super::compute_session_hash; use crate::{ - oidc::{OidcSession, OidcSessionTokens, UserSession}, + oidc::{ + backend::mock::{MockImpl, ISSUER_URL}, + tests, + tests::mock_registered_client_data, + Oidc, OidcSessionTokens, + }, test_utils::test_client_builder, Error, }; - fn fake_session(tokens: OidcSessionTokens) -> OidcSession { - OidcSession { - credentials: ClientCredentials::None { client_id: "test_client_id".to_owned() }, - metadata: ClientMetadata { - redirect_uris: Some(vec![]), // empty vector is ok lol - token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None), - ..ClientMetadata::default() - } - .validate() - .expect("validate client metadata"), - user: UserSession { - meta: SessionMeta { - user_id: ruma::user_id!("@u:e.uk").to_owned(), - device_id: ruma::device_id!("XYZ").to_owned(), - }, - tokens, - issuer_info: AuthenticationServerInfo::new("issuer".to_owned(), None), - }, - } - } - - #[cfg(feature = "e2e-encryption")] #[async_test] - async fn test_oidc_restore_session_lock() -> Result<(), Error> { + async fn test_restore_session_lock() -> Result<(), Error> { // Create a client that will use sqlite databases. let tmp_dir = tempfile::tempdir()?; @@ -305,7 +295,7 @@ mod tests { )?; let session_hash = compute_session_hash(&tokens); - client.oidc().restore_session(fake_session(tokens.clone())).await?; + client.oidc().restore_session(tests::mock_session(tokens.clone())).await?; assert_eq!(client.oidc().session_tokens().unwrap(), tokens); @@ -325,4 +315,173 @@ mod tests { Ok(()) } + + #[async_test] + async fn test_finish_login() -> anyhow::Result<()> { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/_matrix/client/r0/account/whoami")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "user_id": "@joe:example.org", + "device_id": "D3V1C31D", + }))) + .expect(1) + .named("`GET /whoami` good token") + .mount(&server) + .await; + + let tmp_dir = tempfile::tempdir()?; + let client = + test_client_builder(Some(server.uri())).sqlite_store(tmp_dir, None).build().await?; + + let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) }; + + // Restore registered client. + let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None); + let (client_credentials, client_metadata) = mock_registered_client_data(); + oidc.restore_registered_client(issuer_info, client_metadata, client_credentials); + + // Enable cross-process lock. + oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?; + + // Simulate we've done finalize_authorization / restore_session before. + let session_tokens = OidcSessionTokens { + access_token: "access".to_owned(), + refresh_token: Some("refresh".to_owned()), + latest_id_token: None, + }; + oidc.set_session_tokens(session_tokens.clone()); + + // Now, finishing logging will get the user and device ids. + oidc.finish_login().await?; + + let session_meta = client.session_meta().context("should have session meta now")?; + assert_eq!( + *session_meta, + SessionMeta { + user_id: owned_user_id!("@joe:example.org"), + device_id: owned_device_id!("D3V1C31D") + } + ); + + { + // The cross process lock has been correctly updated, and the next attempt to + // take it won't result in a mismatch. + let xp_manager = + oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + let actual_hash = compute_session_hash(&session_tokens); + assert_eq!(guard.db_hash, Some(actual_hash)); + assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash)); + assert!(!guard.hash_mismatch); + } + + Ok(()) + } + + #[async_test] + async fn test_refresh_access_token_twice() -> anyhow::Result<()> { + // This tests that refresh token works, and that it doesn't cause multiple token + // refreshes whenever one spawns two refreshes around the same time. + + let tmp_dir = tempfile::tempdir()?; + let client = test_client_builder(None).sqlite_store(tmp_dir, None).build().await?; + + let prev_tokens = OidcSessionTokens { + access_token: "prev-access-token".to_owned(), + refresh_token: Some("prev-refresh-token".to_owned()), + latest_id_token: None, + }; + + let next_tokens = OidcSessionTokens { + access_token: "next-access-token".to_owned(), + refresh_token: Some("next-refresh-token".to_owned()), + latest_id_token: None, + }; + + let backend = Arc::new( + MockImpl::new() + .next_session_tokens(next_tokens.clone()) + .expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()), + ); + let oidc = Oidc { client: client.clone(), backend: backend.clone() }; + + // Enable cross-process lock. + oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?; + + // Restore the session. + oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?; + + // Immediately try to refresh the access token twice in parallel. + for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await { + result?; + } + + // There should have been at most one refresh. + assert_eq!(*backend.num_refreshes.lock().unwrap(), 1); + + { + // The cross process lock has been correctly updated, and the next attempt to + // take it won't result in a mismatch. + let xp_manager = + oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + let actual_hash = compute_session_hash(&next_tokens); + assert_eq!(guard.db_hash, Some(actual_hash)); + assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash)); + assert!(!guard.hash_mismatch); + } + + Ok(()) + } + + #[async_test] + async fn test_logout() -> anyhow::Result<()> { + let tmp_dir = tempfile::tempdir()?; + let client = test_client_builder(None).sqlite_store(tmp_dir, None).build().await?; + + let tokens = OidcSessionTokens { + access_token: "prev-access-token".to_owned(), + refresh_token: Some("prev-refresh-token".to_owned()), + latest_id_token: None, + }; + + let backend = Arc::new(MockImpl::new()); + let oidc = Oidc { client: client.clone(), backend: backend.clone() }; + + // Enable cross-process lock. + oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?; + + // Restore the session. + oidc.restore_session(tests::mock_session(tokens.clone())).await?; + + let end_session_builder = oidc.logout().await?; + + // No end session builder because our test impl doesn't provide an end session + // endpoint. + assert!(end_session_builder.is_none()); + + // Both the access token and the refresh tokens have been invalidated. + { + let revoked = backend.revoked_tokens.lock().unwrap(); + assert_eq!(revoked.len(), 2); + assert_eq!( + *revoked, + vec![tokens.access_token.clone(), tokens.refresh_token.clone().unwrap(),] + ); + } + + { + // The cross process lock has been correctly updated, and all the hashes are + // empty after a logout. + let xp_manager = + oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?; + let guard = xp_manager.spin_lock().await?; + assert!(guard.db_hash.is_none()); + assert!(guard.hash_guard.is_none()); + assert!(!guard.hash_mismatch); + } + + Ok(()) + } } diff --git a/crates/matrix-sdk/src/oidc/mod.rs b/crates/matrix-sdk/src/oidc/mod.rs index e3d60edd2..80c0528a3 100644 --- a/crates/matrix-sdk/src/oidc/mod.rs +++ b/crates/matrix-sdk/src/oidc/mod.rs @@ -173,28 +173,17 @@ use std::{ }; use as_variant::as_variant; -use chrono::Utc; use eyeball::SharedObservable; use futures_core::Stream; pub use mas_oidc_client::{error, types}; use mas_oidc_client::{ - http_service::HttpService, - jose::jwk::PublicJsonWebKeySet, - requests::{ - authorization_code::{access_token_with_authorization_code, AuthorizationValidationData}, - discovery::discover, - jose::{fetch_jwks, JwtVerificationData}, - refresh_token::refresh_access_token, - registration::register_client, - revocation::revoke_token, - }, + requests::authorization_code::AuthorizationValidationData, types::{ client_credentials::ClientCredentials, errors::ClientError, iana::oauth::OAuthTokenTypeHint, oidc::VerifiedProviderMetadata, registration::{ClientRegistrationResponse, VerifiedClientMetadata}, - requests::AccessTokenResponse, scope::{MatrixApiScopeToken, Scope, ScopeToken}, IdToken, }, @@ -209,17 +198,23 @@ use tracing::{error, trace, warn}; use url::Url; mod auth_code_builder; +mod backend; mod cross_process; mod data_serde; mod end_session_builder; +#[cfg(test)] +mod tests; -use self::cross_process::{ - CrossProcessRefreshLockError, CrossProcessRefreshLockGuard, CrossProcessRefreshManager, -}; pub use self::{ auth_code_builder::{OidcAuthCodeUrlBuilder, OidcAuthorizationData}, end_session_builder::{OidcEndSessionData, OidcEndSessionUrlBuilder}, }; +use self::{ + backend::{server::OidcServer, OidcBackend}, + cross_process::{ + CrossProcessRefreshLockError, CrossProcessRefreshLockGuard, CrossProcessRefreshManager, + }, +}; use crate::{authentication::AuthData, client::SessionChange, Client, RefreshTokenError, Result}; pub(crate) struct OidcCtx { @@ -268,11 +263,14 @@ impl fmt::Debug for OidcAuthData { pub struct Oidc { /// The underlying Matrix API client. client: Client, + + /// The implementation of the OIDC backend. + backend: Arc, } impl Oidc { pub(crate) fn new(client: Client) -> Self { - Self { client } + Self { client: client.clone(), backend: Arc::new(OidcServer::new(client)) } } fn ctx(&self) -> &OidcCtx { @@ -338,11 +336,6 @@ impl Oidc { Ok(()) } - /// Get the `HttpService` to make requests with. - fn http_service(&self) -> HttpService { - HttpService::new(self.client.inner.http_client.clone()) - } - /// The OpenID Connect authentication data. /// /// Returns `None` if the client registration was not restored with @@ -427,7 +420,7 @@ impl Oidc { &self, issuer: &str, ) -> Result { - discover(&self.http_service(), issuer).await.map_err(Into::into) + self.backend.discover(issuer).await } /// Fetch the OpenID Connect metadata of the issuer. @@ -440,14 +433,6 @@ impl Oidc { self.given_provider_metadata(issuer).await } - /// Fetch the OpenID Connect JSON Web Key Set at the given URI. - /// - /// Returns an error if the client registration was not restored, or if an - /// error occurred when fetching the data. - async fn fetch_jwks(&self, jwks_uri: &Url) -> Result { - Ok(fetch_jwks(&self.http_service(), jwks_uri).await?) - } - /// The OpenID Connect metadata of this client used during registration. /// /// Returns `None` if the client registration was not restored with @@ -654,20 +639,16 @@ impl Oidc { client_metadata: VerifiedClientMetadata, software_statement: Option, ) -> Result { - let provider_metadata = self.given_provider_metadata(issuer).await?; + let provider_metadata = self.backend.discover(issuer).await?; + let registration_endpoint = provider_metadata .registration_endpoint .as_ref() .ok_or(OidcError::NoRegistrationSupport)?; - register_client( - &self.http_service(), - registration_endpoint, - client_metadata, - software_statement, - ) - .await - .map_err(Into::into) + self.backend + .register_client(registration_endpoint, client_metadata, software_statement) + .await } /// Set the data of a client that is registered with an OpenID Connect @@ -1007,44 +988,32 @@ impl Oidc { /// Returns an error if a request fails. pub async fn finish_authorization( &self, - code: AuthorizationCode, - ) -> Result { - let provider_metadata = self.provider_metadata().await?; + auth_code: AuthorizationCode, + ) -> Result<(), OidcError> { let data = self.data().ok_or(OidcError::NotAuthenticated)?; let validation_data = data .authorization_data .lock() .await - .remove(&code.state) + .remove(&auth_code.state) .ok_or(OidcError::InvalidState)?; - let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?; - let id_token_verification_data = JwtVerificationData { - issuer: provider_metadata.issuer(), - jwks: &jwks, - client_id: &data.credentials.client_id().to_owned(), - signing_algorithm: data.metadata.id_token_signed_response_alg(), - }; + let provider_metadata = self.provider_metadata().await?; - let (response, id_token) = access_token_with_authorization_code( - &self.http_service(), - data.credentials.clone(), - provider_metadata.token_endpoint(), - code.code, - validation_data, - Some(id_token_verification_data), - Utc::now(), - &mut rng()?, - ) - .await?; + let session_tokens = self + .backend + .trade_authorization_code_for_tokens( + provider_metadata, + data.credentials.clone(), + data.metadata.clone(), + auth_code, + validation_data, + ) + .await?; - self.set_session_tokens(OidcSessionTokens { - access_token: response.access_token.clone(), - refresh_token: response.refresh_token.clone(), - latest_id_token: id_token, - }); + self.set_session_tokens(session_tokens); - Ok(response) + Ok(()) } /// Abort the authorization process. @@ -1074,52 +1043,44 @@ impl Oidc { latest_id_token: Option>, lock: Option, ) -> Result<(), OidcError> { - let provider_metadata = self.provider_metadata().await?; - let data = self.data().ok_or(OidcError::NotAuthenticated)?; - - let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?; - - trace!("Token refresh: attempting to refresh with refresh_token {}", hash(&refresh_token)); - // Do not interrupt refresh access token requests and processing, by detaching // the request sending and response processing. + let provider_metadata = self.provider_metadata().await?; + let this = self.clone(); + let data = self.data().ok_or(OidcError::NotAuthenticated)?; let credentials = data.credentials.clone(); - let signing_algorithm = data.metadata.id_token_signed_response_alg().clone(); + let metadata = data.metadata.clone(); spawn(async move { - let id_token_verification_data = JwtVerificationData { - issuer: provider_metadata.issuer(), - jwks: &jwks, - client_id: &credentials.client_id().to_owned(), - signing_algorithm: &signing_algorithm, - }; + tracing::trace!( + "Token refresh: attempting to refresh with refresh_token {}", + hash(&refresh_token) + ); - match refresh_access_token( - &this.http_service(), - credentials, - provider_metadata.token_endpoint(), - refresh_token.clone(), - None, - Some(id_token_verification_data), - latest_id_token.as_ref(), - Utc::now(), - &mut rng()?, - ) - .await - .map_err(OidcError::from) + match this + .backend + .refresh_access_token( + provider_metadata, + credentials, + &metadata, + refresh_token.clone(), + latest_id_token.clone(), + ) + .await + .map_err(OidcError::from) { - Ok((response, _id_token)) => { + Ok(new_tokens) => { trace!( "Token refresh: new refresh_token: {:?} / access_token: {}", - response.refresh_token.as_ref().map(hash), - hash(&response.access_token) + new_tokens.refresh_token.as_ref().map(hash), + hash(&new_tokens.access_token) ); let tokens = OidcSessionTokens { - access_token: response.access_token, - refresh_token: response.refresh_token.clone().or(Some(refresh_token)), + access_token: new_tokens.access_token, + refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)), latest_id_token, }; @@ -1260,32 +1221,27 @@ impl Oidc { provider_metadata.revocation_endpoint.as_ref().ok_or(OidcError::NoRevocationSupport)?; let tokens = self.session_tokens().ok_or(OidcError::NotAuthenticated)?; - let mut rng = rng()?; // Revoke the access token. - revoke_token( - &self.http_service(), - client_credentials.clone(), - revocation_endpoint, - tokens.access_token, - Some(OAuthTokenTypeHint::AccessToken), - Utc::now(), - &mut rng, - ) - .await?; + self.backend + .revoke_token( + client_credentials.clone(), + revocation_endpoint, + tokens.access_token, + Some(OAuthTokenTypeHint::AccessToken), + ) + .await?; // Revoke the refresh token, if any. if let Some(refresh_token) = tokens.refresh_token { - revoke_token( - &self.http_service(), - client_credentials.clone(), - revocation_endpoint, - refresh_token, - Some(OAuthTokenTypeHint::RefreshToken), - Utc::now(), - &mut rng, - ) - .await?; + self.backend + .revoke_token( + client_credentials.clone(), + revocation_endpoint, + refresh_token, + Some(OAuthTokenTypeHint::RefreshToken), + ) + .await?; } let end_session_builder = @@ -1522,105 +1478,3 @@ fn hash(x: &T) -> u64 { x.hash(&mut hasher); hasher.finish() } - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use mas_oidc_client::types::{ - client_credentials::ClientCredentials, registration::ClientMetadata, - }; - use matrix_sdk_base::SessionMeta; - use matrix_sdk_test::async_test; - use ruma::{ - api::client::discovery::discover_homeserver::AuthenticationServerInfo, OwnedUserId, - }; - use url::Url; - - use crate::{ - oidc::{OidcAccountManagementAction, OidcSession, OidcSessionTokens, UserSession}, - ClientBuilder, - }; - - #[async_test] - async fn test_account_management_url() { - let builder = - ClientBuilder::new().homeserver_url(Url::parse("https://example.com").unwrap()); - let client = builder.build().await.unwrap(); - - client - .restore_session(OidcSession { - credentials: ClientCredentials::None { client_id: "client_id".to_owned() }, - metadata: ClientMetadata { - redirect_uris: Some(vec![Url::parse("https://example.com/login").unwrap()]), - ..Default::default() - } - .validate() - .unwrap(), - user: UserSession { - meta: SessionMeta { - user_id: OwnedUserId::from_str("@user:example.com").unwrap(), - device_id: "device_id".into(), - }, - tokens: OidcSessionTokens { - access_token: "access_token".to_owned(), - refresh_token: None, - latest_id_token: None, - }, - issuer_info: AuthenticationServerInfo::new( - "https://example.com".to_owned(), - Some("https://example.com/account".to_owned()), - ), - }, - }) - .await - .unwrap(); - - assert_eq!( - client.oidc().account_management_url(None).unwrap(), - Some(Url::parse("https://example.com/account").unwrap()) - ); - - assert_eq!( - client - .oidc() - .account_management_url(Some(OidcAccountManagementAction::Profile)) - .unwrap(), - Some(Url::parse("https://example.com/account?action=profile").unwrap()) - ); - - assert_eq!( - client - .oidc() - .account_management_url(Some(OidcAccountManagementAction::SessionsList)) - .unwrap(), - Some(Url::parse("https://example.com/account?action=sessions_list").unwrap()) - ); - - assert_eq!( - client - .oidc() - .account_management_url(Some(OidcAccountManagementAction::SessionView { - device_id: "my_phone".into() - })) - .unwrap(), - Some( - Url::parse("https://example.com/account?action=session_view&device_id=my_phone") - .unwrap() - ) - ); - - assert_eq!( - client - .oidc() - .account_management_url(Some(OidcAccountManagementAction::SessionEnd { - device_id: "my_old_phone".into() - })) - .unwrap(), - Some( - Url::parse("https://example.com/account?action=session_end&device_id=my_old_phone") - .unwrap() - ) - ); - } -} diff --git a/crates/matrix-sdk/src/oidc/tests.rs b/crates/matrix-sdk/src/oidc/tests.rs new file mode 100644 index 000000000..072a6bc14 --- /dev/null +++ b/crates/matrix-sdk/src/oidc/tests.rs @@ -0,0 +1,299 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use assert_matches::assert_matches; +use mas_oidc_client::{ + requests::authorization_code::AuthorizationValidationData, + types::{ + client_credentials::ClientCredentials, + errors::ClientErrorCode, + iana::oauth::OAuthClientAuthenticationMethod, + registration::{ClientMetadata, VerifiedClientMetadata}, + }, +}; +use matrix_sdk_base::SessionMeta; +#[cfg(test)] +use matrix_sdk_test::async_test; +use ruma::{api::client::discovery::discover_homeserver::AuthenticationServerInfo, owned_user_id}; +use url::Url; + +use super::{ + backend::mock::{MockImpl, AUTHORIZATION_URL, ISSUER_URL}, + AuthorizationCode, AuthorizationError, AuthorizationResponse, Oidc, + OidcAccountManagementAction, OidcError, OidcSession, OidcSessionTokens, + RedirectUriQueryParseError, UserSession, +}; +use crate::test_utils::test_client_builder; + +const CLIENT_ID: &str = "test_client_id"; + +pub fn mock_registered_client_data() -> (ClientCredentials, VerifiedClientMetadata) { + ( + ClientCredentials::None { client_id: CLIENT_ID.to_owned() }, + ClientMetadata { + redirect_uris: Some(vec![]), // empty vector is ok lol + token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None), + ..ClientMetadata::default() + } + .validate() + .expect("validate client metadata"), + ) +} + +pub fn mock_session(tokens: OidcSessionTokens) -> OidcSession { + let (credentials, metadata) = mock_registered_client_data(); + OidcSession { + credentials, + metadata, + user: UserSession { + meta: SessionMeta { + user_id: ruma::user_id!("@u:e.uk").to_owned(), + device_id: ruma::device_id!("XYZ").to_owned(), + }, + tokens, + issuer_info: AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None), + }, + } +} + +#[async_test] +async fn test_account_management_url() { + let builder = test_client_builder(Some("https://example.com".to_owned())); + let client = builder.build().await.unwrap(); + + client + .restore_session(OidcSession { + credentials: ClientCredentials::None { client_id: "client_id".to_owned() }, + + metadata: ClientMetadata { + redirect_uris: Some(vec![Url::parse("https://example.com/login").unwrap()]), + ..Default::default() + } + .validate() + .unwrap(), + + user: UserSession { + meta: SessionMeta { + user_id: owned_user_id!("@user:example.com"), + device_id: "device_id".into(), + }, + tokens: OidcSessionTokens { + access_token: "access_token".to_owned(), + refresh_token: None, + latest_id_token: None, + }, + issuer_info: AuthenticationServerInfo::new( + "https://example.com".to_owned(), + Some("https://example.com/account".to_owned()), + ), + }, + }) + .await + .unwrap(); + + assert_eq!( + client.oidc().account_management_url(None).unwrap(), + Some(Url::parse("https://example.com/account").unwrap()) + ); + + assert_eq!( + client.oidc().account_management_url(Some(OidcAccountManagementAction::Profile)).unwrap(), + Some(Url::parse("https://example.com/account?action=profile").unwrap()) + ); + + assert_eq!( + client + .oidc() + .account_management_url(Some(OidcAccountManagementAction::SessionsList)) + .unwrap(), + Some(Url::parse("https://example.com/account?action=sessions_list").unwrap()) + ); + + assert_eq!( + client + .oidc() + .account_management_url(Some(OidcAccountManagementAction::SessionView { + device_id: "my_phone".into() + })) + .unwrap(), + Some( + Url::parse("https://example.com/account?action=session_view&device_id=my_phone") + .unwrap() + ) + ); + + assert_eq!( + client + .oidc() + .account_management_url(Some(OidcAccountManagementAction::SessionEnd { + device_id: "my_old_phone".into() + })) + .unwrap(), + Some( + Url::parse("https://example.com/account?action=session_end&device_id=my_old_phone") + .unwrap() + ) + ); +} + +#[async_test] +async fn test_login() -> anyhow::Result<()> { + let client = test_client_builder(None).build().await?; + + let device_id = "D3V1C31D".to_owned(); // yo this is 1999 speaking + + let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) }; + + let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None); + let (client_credentials, client_metadata) = mock_registered_client_data(); + oidc.restore_registered_client(issuer_info, client_metadata, client_credentials); + + let redirect_uri_str = "http://matrix.example.com/oidc/callback"; + let redirect_uri = Url::parse(redirect_uri_str)?; + let mut authorization_data = oidc.login(redirect_uri, Some(device_id.clone()))?.build().await?; + + tracing::debug!("authorization data URL = {}", authorization_data.url); + + let mut num_expected = 6; + let mut nonce = None; + + for (key, val) in authorization_data.url.query_pairs() { + match &*key { + "response_type" => { + assert_eq!(val, "code"); + num_expected -= 1; + } + "client_id" => { + assert_eq!(val, CLIENT_ID); + num_expected -= 1; + } + "redirect_uri" => { + assert_eq!(val, redirect_uri_str); + num_expected -= 1; + } + "scope" => { + assert_eq!(val, format!("openid urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id}")); + num_expected -= 1; + } + "state" => { + num_expected -= 1; + assert_eq!(val, authorization_data.state); + } + "nonce" => { + num_expected -= 1; + nonce = Some(val); + } + _ => panic!("unexpected query parameter: {key}={val}"), + } + } + + assert_eq!(num_expected, 0); + + let data = oidc.data().unwrap(); + let authorization_data_guard = data.authorization_data.lock().await; + + let state = authorization_data_guard.get(&authorization_data.state).context("missing state")?; + let nonce = nonce.context("missing nonce")?; + assert_eq!(nonce, state.nonce); + + authorization_data.url.set_query(None); + assert_eq!(authorization_data.url, Url::parse(AUTHORIZATION_URL).unwrap(),); + + Ok(()) +} + +#[test] +fn test_authorization_response() -> anyhow::Result<()> { + let uri = Url::parse("https://example.com")?; + assert_matches!( + AuthorizationResponse::parse_uri(&uri), + Err(RedirectUriQueryParseError::MissingQuery) + ); + + let uri = Url::parse("https://example.com?code=123&state=456")?; + assert_matches!( + AuthorizationResponse::parse_uri(&uri), + Ok(AuthorizationResponse::Success(AuthorizationCode { code, state })) => { + assert_eq!(code, "123"); + assert_eq!(state, "456"); + } + ); + + let uri = Url::parse("https://example.com?error=invalid_grant&state=456")?; + assert_matches!( + AuthorizationResponse::parse_uri(&uri), + Ok(AuthorizationResponse::Error(AuthorizationError { error, state })) => { + assert_eq!(error.error, ClientErrorCode::InvalidGrant); + assert_eq!(error.error_description, None); + assert_eq!(state, "456"); + } + ); + + Ok(()) +} + +#[async_test] +async fn test_finish_authorization() -> anyhow::Result<()> { + let client = test_client_builder(None).build().await?; + + let session_tokens = OidcSessionTokens { + access_token: "4cc3ss".to_owned(), + refresh_token: Some("r3fr3$h".to_owned()), + latest_id_token: None, + }; + let oidc = Oidc { + client: client.clone(), + backend: Arc::new(MockImpl::new().next_session_tokens(session_tokens.clone())), + }; + + let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None); + let (client_credentials, client_metadata) = mock_registered_client_data(); + oidc.restore_registered_client(issuer_info, client_metadata, client_credentials); + + // If the state is missing, then any attempt to finish authorizing will fail. + let res = oidc + .finish_authorization(AuthorizationCode { code: "42".to_owned(), state: "none".to_owned() }) + .await; + + assert_matches!(res, Err(OidcError::InvalidState)); + assert!(oidc.session_tokens().is_none()); + + // Assuming a non-empty state "123"... + let state = "state".to_owned(); + let redirect_uri = "http://matrix.example.com/oidc/callback"; + let auth_validation_data = AuthorizationValidationData { + state: state.clone(), + nonce: "nonce".to_owned(), + redirect_uri: Url::parse(redirect_uri)?, + code_challenge_verifier: None, + }; + + { + let data = oidc.data().context("missing data")?; + let prev = data.authorization_data.lock().await.insert(state.clone(), { + AuthorizationValidationData { ..auth_validation_data.clone() } + }); + assert!(prev.is_none()); + } + + // Finishing the authorization for another state won't work. + let res = oidc + .finish_authorization(AuthorizationCode { + code: "1337".to_owned(), + state: "none".to_owned(), + }) + .await; + + assert_matches!(res, Err(OidcError::InvalidState)); + assert!(oidc.session_tokens().is_none()); + assert!(oidc.data().unwrap().authorization_data.lock().await.get(&state).is_some()); + + // Finishing the authorization for the expected state will work. + oidc.finish_authorization(AuthorizationCode { code: "1337".to_owned(), state: state.clone() }) + .await?; + + assert_eq!(oidc.session_tokens(), Some(session_tokens)); + assert!(oidc.data().unwrap().authorization_data.lock().await.get(&state).is_none()); + + Ok(()) +}