mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-09 16:34:32 -04:00
Add tests for OIDC (#2558)
* chore(oidc): put impl of OIDC server behind a trait and add tests for the OIDC flow
chore(oidc): add first tests for login/AuthorizationResponse
chore(oidc): add tests for finish_authorization/finish_login/refresh_access_token
chore(oidc): add test for logout
chore(🤷): clippy/fmt
* chore(oidc): move account management test to the new `tests` module
* chore(oidc): address review comments
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3073,6 +3073,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_html_form",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
|
||||
@@ -144,6 +144,7 @@ dirs = "5.0.1"
|
||||
futures-executor = { workspace = true }
|
||||
matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base", default_features = false, features = ["testing"] }
|
||||
matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test" }
|
||||
serde_urlencoded = "0.7.1"
|
||||
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
||||
|
||||
@@ -14,12 +14,9 @@
|
||||
|
||||
use std::{collections::HashSet, num::NonZeroU32};
|
||||
|
||||
use chrono::Utc;
|
||||
use language_tags::LanguageTag;
|
||||
use mas_oidc_client::{
|
||||
requests::authorization_code::{
|
||||
build_authorization_url, build_par_authorization_url, AuthorizationRequestData,
|
||||
},
|
||||
requests::authorization_code::{build_authorization_url, AuthorizationRequestData},
|
||||
types::{
|
||||
requests::{Display, Prompt},
|
||||
scope::Scope,
|
||||
@@ -187,16 +184,15 @@ impl OidcAuthCodeUrlBuilder {
|
||||
let client_credentials =
|
||||
oidc.client_credentials().ok_or(OidcError::NotAuthenticated)?;
|
||||
|
||||
let res = build_par_authorization_url(
|
||||
&oidc.http_service(),
|
||||
client_credentials.clone(),
|
||||
par_endpoint,
|
||||
authorization_endpoint.clone(),
|
||||
authorization_data.clone(),
|
||||
Utc::now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await;
|
||||
let res = oidc
|
||||
.backend
|
||||
.build_par_authorization_url(
|
||||
client_credentials.clone(),
|
||||
par_endpoint,
|
||||
authorization_endpoint.clone(),
|
||||
authorization_data.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(res) => res,
|
||||
@@ -209,8 +205,8 @@ impl OidcAuthCodeUrlBuilder {
|
||||
|
||||
// If the client said that PAR should be enforced, we should not try without
|
||||
// it, so just return the error.
|
||||
if client_metadata.require_pushed_authorization_requests.unwrap_or_default() {
|
||||
return Err(error.into());
|
||||
if client_metadata.require_pushed_authorization_requests.unwrap_or(false) {
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
error!(
|
||||
|
||||
192
crates/matrix-sdk/src/oidc/backend/mock.rs
Normal file
192
crates/matrix-sdk/src/oidc/backend/mock.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for that specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Test implementation of the OIDC backend.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use http::StatusCode;
|
||||
use mas_oidc_client::{
|
||||
error::{
|
||||
DiscoveryError, Error as OidcClientError, ErrorBody as OidcErrorBody,
|
||||
HttpError as OidcHttpError, TokenRefreshError, TokenRequestError,
|
||||
},
|
||||
requests::authorization_code::{AuthorizationRequestData, AuthorizationValidationData},
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
errors::ClientErrorCode,
|
||||
iana::oauth::OAuthTokenTypeHint,
|
||||
oidc::{ProviderMetadata, VerifiedProviderMetadata},
|
||||
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
|
||||
IdToken,
|
||||
},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::{OidcBackend, OidcError, RefreshedSessionTokens};
|
||||
use crate::oidc::{AuthorizationCode, OidcSessionTokens};
|
||||
|
||||
pub(crate) const ISSUER_URL: &str = "https://oidc.example.com/issuer";
|
||||
pub(crate) const AUTHORIZATION_URL: &str = "https://oidc.example.com/authorization";
|
||||
pub(crate) const REVOCATION_URL: &str = "https://oidc.example.com/revocation";
|
||||
pub(crate) const TOKEN_URL: &str = "https://oidc.example.com/token";
|
||||
pub(crate) const JWKS_URL: &str = "https://oidc.example.com/jwks";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct MockImpl {
|
||||
/// Must be an HTTPS URL.
|
||||
issuer: String,
|
||||
|
||||
/// Must be an HTTPS URL.
|
||||
authorization_endpoint: String,
|
||||
|
||||
/// Must be an HTTPS URL.
|
||||
token_endpoint: String,
|
||||
|
||||
/// Must be an HTTPS URL.
|
||||
jwks_uri: String,
|
||||
|
||||
/// Must be an HTTPS URL.
|
||||
revocation_endpoint: String,
|
||||
|
||||
/// The next session tokens that will be returned by a login or refresh.
|
||||
next_session_tokens: Option<OidcSessionTokens>,
|
||||
|
||||
/// The next refresh token that's expected for a refresh.
|
||||
expected_refresh_token: Option<String>,
|
||||
|
||||
/// Number of refreshes that effectively happened.
|
||||
pub num_refreshes: Arc<Mutex<u32>>,
|
||||
|
||||
/// Tokens that have been revoked with `revoke_token`.
|
||||
pub revoked_tokens: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
impl MockImpl {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
issuer: ISSUER_URL.to_owned(),
|
||||
authorization_endpoint: AUTHORIZATION_URL.to_owned(),
|
||||
token_endpoint: TOKEN_URL.to_owned(),
|
||||
jwks_uri: JWKS_URL.to_owned(),
|
||||
revocation_endpoint: REVOCATION_URL.to_owned(),
|
||||
next_session_tokens: None,
|
||||
expected_refresh_token: None,
|
||||
num_refreshes: Default::default(),
|
||||
revoked_tokens: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_session_tokens(mut self, next_session_tokens: OidcSessionTokens) -> Self {
|
||||
self.next_session_tokens = Some(next_session_tokens);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn expected_refresh_token(mut self, refresh_token: String) -> Self {
|
||||
self.expected_refresh_token = Some(refresh_token);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OidcBackend for MockImpl {
|
||||
async fn discover(&self, issuer: &str) -> Result<VerifiedProviderMetadata, OidcError> {
|
||||
Ok(ProviderMetadata {
|
||||
issuer: Some(self.issuer.clone()),
|
||||
authorization_endpoint: Some(Url::parse(&self.authorization_endpoint).unwrap()),
|
||||
revocation_endpoint: Some(Url::parse(&self.revocation_endpoint).unwrap()),
|
||||
token_endpoint: Some(Url::parse(&self.token_endpoint).unwrap()),
|
||||
jwks_uri: Some(Url::parse(&self.jwks_uri).unwrap()),
|
||||
response_types_supported: Some(vec![]),
|
||||
subject_types_supported: Some(vec![]),
|
||||
id_token_signing_alg_values_supported: Some(vec![]),
|
||||
..Default::default()
|
||||
}
|
||||
.validate(issuer)
|
||||
.map_err(DiscoveryError::from)?)
|
||||
}
|
||||
|
||||
async fn trade_authorization_code_for_tokens(
|
||||
&self,
|
||||
_provider_metadata: VerifiedProviderMetadata,
|
||||
_credentials: ClientCredentials,
|
||||
_metadata: VerifiedClientMetadata,
|
||||
_auth_code: AuthorizationCode,
|
||||
_validation_data: AuthorizationValidationData,
|
||||
) -> Result<OidcSessionTokens, OidcError> {
|
||||
Ok(self
|
||||
.next_session_tokens
|
||||
.as_ref()
|
||||
.expect("missing next session tokens in testing")
|
||||
.clone())
|
||||
}
|
||||
|
||||
async fn register_client(
|
||||
&self,
|
||||
_registration_endpoint: &Url,
|
||||
_client_metadata: VerifiedClientMetadata,
|
||||
_software_statement: Option<String>,
|
||||
) -> Result<ClientRegistrationResponse, OidcError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn build_par_authorization_url(
|
||||
&self,
|
||||
_client_credentials: ClientCredentials,
|
||||
_par_endpoint: &Url,
|
||||
_authorization_endpoint: Url,
|
||||
_authorization_data: AuthorizationRequestData,
|
||||
) -> Result<(Url, AuthorizationValidationData), OidcError> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn revoke_token(
|
||||
&self,
|
||||
_client_credentials: ClientCredentials,
|
||||
_revocation_endpoint: &Url,
|
||||
token: String,
|
||||
_token_type_hint: Option<OAuthTokenTypeHint>,
|
||||
) -> Result<(), OidcError> {
|
||||
self.revoked_tokens.lock().unwrap().push(token);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn refresh_access_token(
|
||||
&self,
|
||||
_provider_metadata: VerifiedProviderMetadata,
|
||||
_credentials: ClientCredentials,
|
||||
_metadata: &VerifiedClientMetadata,
|
||||
refresh_token: String,
|
||||
_latest_id_token: Option<IdToken<'static>>,
|
||||
) -> Result<RefreshedSessionTokens, OidcError> {
|
||||
if Some(refresh_token) != self.expected_refresh_token {
|
||||
Err(OidcError::Oidc(OidcClientError::TokenRefresh(TokenRefreshError::Token(
|
||||
TokenRequestError::Http(OidcHttpError {
|
||||
body: Some(OidcErrorBody {
|
||||
error: ClientErrorCode::InvalidGrant,
|
||||
error_description: None,
|
||||
}),
|
||||
status: StatusCode::from_u16(400).unwrap(),
|
||||
}),
|
||||
))))
|
||||
} else {
|
||||
*self.num_refreshes.lock().unwrap() += 1;
|
||||
let next_tokens = self.next_session_tokens.clone().expect("missing next tokens");
|
||||
Ok(RefreshedSessionTokens {
|
||||
access_token: next_tokens.access_token,
|
||||
refresh_token: next_tokens.refresh_token,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
87
crates/matrix-sdk/src/oidc/backend/mod.rs
Normal file
87
crates/matrix-sdk/src/oidc/backend/mod.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for that specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Trait for defining an implementation for an OIDC backend.
|
||||
//!
|
||||
//! Used mostly for testing purposes.
|
||||
|
||||
use mas_oidc_client::{
|
||||
requests::authorization_code::{AuthorizationRequestData, AuthorizationValidationData},
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
iana::oauth::OAuthTokenTypeHint,
|
||||
oidc::VerifiedProviderMetadata,
|
||||
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
|
||||
IdToken,
|
||||
},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::{AuthorizationCode, OidcError, OidcSessionTokens};
|
||||
|
||||
pub(crate) mod server;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) mod mock;
|
||||
|
||||
pub(super) struct RefreshedSessionTokens {
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub(super) trait OidcBackend: std::fmt::Debug + Send + Sync {
|
||||
async fn discover(&self, issuer: &str) -> Result<VerifiedProviderMetadata, OidcError>;
|
||||
|
||||
async fn register_client(
|
||||
&self,
|
||||
registration_endpoint: &Url,
|
||||
client_metadata: VerifiedClientMetadata,
|
||||
software_statement: Option<String>,
|
||||
) -> Result<ClientRegistrationResponse, OidcError>;
|
||||
|
||||
async fn trade_authorization_code_for_tokens(
|
||||
&self,
|
||||
provider_metadata: VerifiedProviderMetadata,
|
||||
credentials: ClientCredentials,
|
||||
metadata: VerifiedClientMetadata,
|
||||
auth_code: AuthorizationCode,
|
||||
validation_data: AuthorizationValidationData,
|
||||
) -> Result<OidcSessionTokens, OidcError>;
|
||||
|
||||
async fn refresh_access_token(
|
||||
&self,
|
||||
provider_metadata: VerifiedProviderMetadata,
|
||||
credentials: ClientCredentials,
|
||||
metadata: &VerifiedClientMetadata,
|
||||
refresh_token: String,
|
||||
latest_id_token: Option<IdToken<'static>>,
|
||||
) -> Result<RefreshedSessionTokens, OidcError>;
|
||||
|
||||
async fn build_par_authorization_url(
|
||||
&self,
|
||||
client_credentials: ClientCredentials,
|
||||
par_endpoint: &Url,
|
||||
authorization_endpoint: Url,
|
||||
authorization_data: AuthorizationRequestData,
|
||||
) -> Result<(Url, AuthorizationValidationData), OidcError>;
|
||||
|
||||
async fn revoke_token(
|
||||
&self,
|
||||
client_credentials: ClientCredentials,
|
||||
revocation_endpoint: &Url,
|
||||
token: String,
|
||||
token_type_hint: Option<OAuthTokenTypeHint>,
|
||||
) -> Result<(), OidcError>;
|
||||
}
|
||||
203
crates/matrix-sdk/src/oidc/backend/server.rs
Normal file
203
crates/matrix-sdk/src/oidc/backend/server.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for that specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Actual implementation of the OIDC backend, using the mas_oidc_client
|
||||
//! implementation.
|
||||
|
||||
use chrono::Utc;
|
||||
use mas_oidc_client::{
|
||||
http_service::HttpService,
|
||||
jose::jwk::PublicJsonWebKeySet,
|
||||
requests::{
|
||||
authorization_code::{
|
||||
access_token_with_authorization_code, build_par_authorization_url,
|
||||
AuthorizationRequestData, AuthorizationValidationData,
|
||||
},
|
||||
discovery::discover,
|
||||
jose::{fetch_jwks, JwtVerificationData},
|
||||
refresh_token::refresh_access_token,
|
||||
registration::register_client,
|
||||
revocation::revoke_token,
|
||||
},
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
iana::oauth::OAuthTokenTypeHint,
|
||||
oidc::VerifiedProviderMetadata,
|
||||
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
|
||||
IdToken,
|
||||
},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use super::{OidcBackend, OidcError, RefreshedSessionTokens};
|
||||
use crate::{
|
||||
oidc::{rng, AuthorizationCode, OidcSessionTokens},
|
||||
Client,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct OidcServer {
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl OidcServer {
|
||||
pub(crate) fn new(client: Client) -> Self {
|
||||
Self { client }
|
||||
}
|
||||
|
||||
fn http_service(&self) -> HttpService {
|
||||
HttpService::new(self.client.inner.http_client.clone())
|
||||
}
|
||||
|
||||
/// Fetch the OpenID Connect JSON Web Key Set at the given URI.
|
||||
///
|
||||
/// Returns an error if the client registration was not restored, or if an
|
||||
/// error occurred when fetching the data.
|
||||
async fn fetch_jwks(&self, uri: &Url) -> Result<PublicJsonWebKeySet, OidcError> {
|
||||
fetch_jwks(&self.http_service(), uri).await.map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OidcBackend for OidcServer {
|
||||
async fn discover(&self, issuer: &str) -> Result<VerifiedProviderMetadata, OidcError> {
|
||||
discover(&self.http_service(), issuer).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn trade_authorization_code_for_tokens(
|
||||
&self,
|
||||
provider_metadata: VerifiedProviderMetadata,
|
||||
credentials: ClientCredentials,
|
||||
metadata: VerifiedClientMetadata,
|
||||
auth_code: AuthorizationCode,
|
||||
validation_data: AuthorizationValidationData,
|
||||
) -> Result<OidcSessionTokens, OidcError> {
|
||||
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
|
||||
|
||||
let id_token_verification_data = JwtVerificationData {
|
||||
issuer: provider_metadata.issuer(),
|
||||
jwks: &jwks,
|
||||
client_id: &credentials.client_id().to_owned(),
|
||||
signing_algorithm: metadata.id_token_signed_response_alg(),
|
||||
};
|
||||
|
||||
let (response, id_token) = access_token_with_authorization_code(
|
||||
&self.http_service(),
|
||||
credentials.clone(),
|
||||
provider_metadata.token_endpoint(),
|
||||
auth_code.code,
|
||||
validation_data,
|
||||
Some(id_token_verification_data),
|
||||
Utc::now(),
|
||||
&mut rng()?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(OidcSessionTokens {
|
||||
access_token: response.access_token,
|
||||
refresh_token: response.refresh_token,
|
||||
latest_id_token: id_token,
|
||||
})
|
||||
}
|
||||
|
||||
async fn refresh_access_token(
|
||||
&self,
|
||||
provider_metadata: VerifiedProviderMetadata,
|
||||
credentials: ClientCredentials,
|
||||
metadata: &VerifiedClientMetadata,
|
||||
refresh_token: String,
|
||||
latest_id_token: Option<IdToken<'static>>,
|
||||
) -> Result<RefreshedSessionTokens, OidcError> {
|
||||
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
|
||||
|
||||
let id_token_verification_data = JwtVerificationData {
|
||||
issuer: provider_metadata.issuer(),
|
||||
jwks: &jwks,
|
||||
client_id: &credentials.client_id().to_owned(),
|
||||
signing_algorithm: &metadata.id_token_signed_response_alg().clone(),
|
||||
};
|
||||
|
||||
refresh_access_token(
|
||||
&self.http_service(),
|
||||
credentials,
|
||||
provider_metadata.token_endpoint(),
|
||||
refresh_token,
|
||||
None,
|
||||
Some(id_token_verification_data),
|
||||
latest_id_token.as_ref(),
|
||||
Utc::now(),
|
||||
&mut rng()?,
|
||||
)
|
||||
.await
|
||||
.map(|(response, _id_token)| RefreshedSessionTokens {
|
||||
access_token: response.access_token,
|
||||
refresh_token: response.refresh_token,
|
||||
})
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn register_client(
|
||||
&self,
|
||||
registration_endpoint: &Url,
|
||||
client_metadata: VerifiedClientMetadata,
|
||||
software_statement: Option<String>,
|
||||
) -> Result<ClientRegistrationResponse, OidcError> {
|
||||
register_client(
|
||||
&self.http_service(),
|
||||
registration_endpoint,
|
||||
client_metadata,
|
||||
software_statement,
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn build_par_authorization_url(
|
||||
&self,
|
||||
client_credentials: ClientCredentials,
|
||||
par_endpoint: &Url,
|
||||
authorization_endpoint: Url,
|
||||
authorization_data: AuthorizationRequestData,
|
||||
) -> Result<(Url, AuthorizationValidationData), OidcError> {
|
||||
Ok(build_par_authorization_url(
|
||||
&self.http_service(),
|
||||
client_credentials,
|
||||
par_endpoint,
|
||||
authorization_endpoint,
|
||||
authorization_data,
|
||||
Utc::now(),
|
||||
&mut rng()?,
|
||||
)
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn revoke_token(
|
||||
&self,
|
||||
client_credentials: ClientCredentials,
|
||||
revocation_endpoint: &Url,
|
||||
token: String,
|
||||
token_type_hint: Option<OAuthTokenTypeHint>,
|
||||
) -> Result<(), OidcError> {
|
||||
Ok(revoke_token(
|
||||
&self.http_service(),
|
||||
client_credentials,
|
||||
revocation_endpoint,
|
||||
token,
|
||||
token_type_hint,
|
||||
Utc::now(),
|
||||
&mut rng()?,
|
||||
)
|
||||
.await?)
|
||||
}
|
||||
}
|
||||
@@ -241,47 +241,37 @@ pub enum CrossProcessRefreshLockError {
|
||||
DuplicatedLock,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(all(test, feature = "e2e-encryption"))]
|
||||
mod tests {
|
||||
use mas_oidc_client::types::{
|
||||
client_credentials::ClientCredentials, iana::oauth::OAuthClientAuthenticationMethod,
|
||||
registration::ClientMetadata,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures_util::future::join_all;
|
||||
use matrix_sdk_base::SessionMeta;
|
||||
use matrix_sdk_test::async_test;
|
||||
use ruma::api::client::discovery::discover_homeserver::AuthenticationServerInfo;
|
||||
use ruma::{
|
||||
api::client::discovery::discover_homeserver::AuthenticationServerInfo, owned_device_id,
|
||||
owned_user_id,
|
||||
};
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
Mock, MockServer, ResponseTemplate,
|
||||
};
|
||||
|
||||
use super::compute_session_hash;
|
||||
use crate::{
|
||||
oidc::{OidcSession, OidcSessionTokens, UserSession},
|
||||
oidc::{
|
||||
backend::mock::{MockImpl, ISSUER_URL},
|
||||
tests,
|
||||
tests::mock_registered_client_data,
|
||||
Oidc, OidcSessionTokens,
|
||||
},
|
||||
test_utils::test_client_builder,
|
||||
Error,
|
||||
};
|
||||
|
||||
fn fake_session(tokens: OidcSessionTokens) -> OidcSession {
|
||||
OidcSession {
|
||||
credentials: ClientCredentials::None { client_id: "test_client_id".to_owned() },
|
||||
metadata: ClientMetadata {
|
||||
redirect_uris: Some(vec![]), // empty vector is ok lol
|
||||
token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None),
|
||||
..ClientMetadata::default()
|
||||
}
|
||||
.validate()
|
||||
.expect("validate client metadata"),
|
||||
user: UserSession {
|
||||
meta: SessionMeta {
|
||||
user_id: ruma::user_id!("@u:e.uk").to_owned(),
|
||||
device_id: ruma::device_id!("XYZ").to_owned(),
|
||||
},
|
||||
tokens,
|
||||
issuer_info: AuthenticationServerInfo::new("issuer".to_owned(), None),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
#[async_test]
|
||||
async fn test_oidc_restore_session_lock() -> Result<(), Error> {
|
||||
async fn test_restore_session_lock() -> Result<(), Error> {
|
||||
// Create a client that will use sqlite databases.
|
||||
|
||||
let tmp_dir = tempfile::tempdir()?;
|
||||
@@ -305,7 +295,7 @@ mod tests {
|
||||
)?;
|
||||
|
||||
let session_hash = compute_session_hash(&tokens);
|
||||
client.oidc().restore_session(fake_session(tokens.clone())).await?;
|
||||
client.oidc().restore_session(tests::mock_session(tokens.clone())).await?;
|
||||
|
||||
assert_eq!(client.oidc().session_tokens().unwrap(), tokens);
|
||||
|
||||
@@ -325,4 +315,173 @@ mod tests {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_finish_login() -> anyhow::Result<()> {
|
||||
let server = MockServer::start().await;
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/_matrix/client/r0/account/whoami"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
|
||||
"user_id": "@joe:example.org",
|
||||
"device_id": "D3V1C31D",
|
||||
})))
|
||||
.expect(1)
|
||||
.named("`GET /whoami` good token")
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let tmp_dir = tempfile::tempdir()?;
|
||||
let client =
|
||||
test_client_builder(Some(server.uri())).sqlite_store(tmp_dir, None).build().await?;
|
||||
|
||||
let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) };
|
||||
|
||||
// Restore registered client.
|
||||
let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None);
|
||||
let (client_credentials, client_metadata) = mock_registered_client_data();
|
||||
oidc.restore_registered_client(issuer_info, client_metadata, client_credentials);
|
||||
|
||||
// Enable cross-process lock.
|
||||
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
|
||||
|
||||
// Simulate we've done finalize_authorization / restore_session before.
|
||||
let session_tokens = OidcSessionTokens {
|
||||
access_token: "access".to_owned(),
|
||||
refresh_token: Some("refresh".to_owned()),
|
||||
latest_id_token: None,
|
||||
};
|
||||
oidc.set_session_tokens(session_tokens.clone());
|
||||
|
||||
// Now, finishing logging will get the user and device ids.
|
||||
oidc.finish_login().await?;
|
||||
|
||||
let session_meta = client.session_meta().context("should have session meta now")?;
|
||||
assert_eq!(
|
||||
*session_meta,
|
||||
SessionMeta {
|
||||
user_id: owned_user_id!("@joe:example.org"),
|
||||
device_id: owned_device_id!("D3V1C31D")
|
||||
}
|
||||
);
|
||||
|
||||
{
|
||||
// The cross process lock has been correctly updated, and the next attempt to
|
||||
// take it won't result in a mismatch.
|
||||
let xp_manager =
|
||||
oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
|
||||
let guard = xp_manager.spin_lock().await?;
|
||||
let actual_hash = compute_session_hash(&session_tokens);
|
||||
assert_eq!(guard.db_hash, Some(actual_hash));
|
||||
assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
|
||||
assert!(!guard.hash_mismatch);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
|
||||
// This tests that refresh token works, and that it doesn't cause multiple token
|
||||
// refreshes whenever one spawns two refreshes around the same time.
|
||||
|
||||
let tmp_dir = tempfile::tempdir()?;
|
||||
let client = test_client_builder(None).sqlite_store(tmp_dir, None).build().await?;
|
||||
|
||||
let prev_tokens = OidcSessionTokens {
|
||||
access_token: "prev-access-token".to_owned(),
|
||||
refresh_token: Some("prev-refresh-token".to_owned()),
|
||||
latest_id_token: None,
|
||||
};
|
||||
|
||||
let next_tokens = OidcSessionTokens {
|
||||
access_token: "next-access-token".to_owned(),
|
||||
refresh_token: Some("next-refresh-token".to_owned()),
|
||||
latest_id_token: None,
|
||||
};
|
||||
|
||||
let backend = Arc::new(
|
||||
MockImpl::new()
|
||||
.next_session_tokens(next_tokens.clone())
|
||||
.expected_refresh_token(prev_tokens.refresh_token.clone().unwrap()),
|
||||
);
|
||||
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
|
||||
|
||||
// Enable cross-process lock.
|
||||
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
|
||||
|
||||
// Restore the session.
|
||||
oidc.restore_session(tests::mock_session(prev_tokens.clone())).await?;
|
||||
|
||||
// Immediately try to refresh the access token twice in parallel.
|
||||
for result in join_all([oidc.refresh_access_token(), oidc.refresh_access_token()]).await {
|
||||
result?;
|
||||
}
|
||||
|
||||
// There should have been at most one refresh.
|
||||
assert_eq!(*backend.num_refreshes.lock().unwrap(), 1);
|
||||
|
||||
{
|
||||
// The cross process lock has been correctly updated, and the next attempt to
|
||||
// take it won't result in a mismatch.
|
||||
let xp_manager =
|
||||
oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
|
||||
let guard = xp_manager.spin_lock().await?;
|
||||
let actual_hash = compute_session_hash(&next_tokens);
|
||||
assert_eq!(guard.db_hash, Some(actual_hash));
|
||||
assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
|
||||
assert!(!guard.hash_mismatch);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_logout() -> anyhow::Result<()> {
|
||||
let tmp_dir = tempfile::tempdir()?;
|
||||
let client = test_client_builder(None).sqlite_store(tmp_dir, None).build().await?;
|
||||
|
||||
let tokens = OidcSessionTokens {
|
||||
access_token: "prev-access-token".to_owned(),
|
||||
refresh_token: Some("prev-refresh-token".to_owned()),
|
||||
latest_id_token: None,
|
||||
};
|
||||
|
||||
let backend = Arc::new(MockImpl::new());
|
||||
let oidc = Oidc { client: client.clone(), backend: backend.clone() };
|
||||
|
||||
// Enable cross-process lock.
|
||||
oidc.enable_cross_process_refresh_lock("lock".to_owned()).await?;
|
||||
|
||||
// Restore the session.
|
||||
oidc.restore_session(tests::mock_session(tokens.clone())).await?;
|
||||
|
||||
let end_session_builder = oidc.logout().await?;
|
||||
|
||||
// No end session builder because our test impl doesn't provide an end session
|
||||
// endpoint.
|
||||
assert!(end_session_builder.is_none());
|
||||
|
||||
// Both the access token and the refresh tokens have been invalidated.
|
||||
{
|
||||
let revoked = backend.revoked_tokens.lock().unwrap();
|
||||
assert_eq!(revoked.len(), 2);
|
||||
assert_eq!(
|
||||
*revoked,
|
||||
vec![tokens.access_token.clone(), tokens.refresh_token.clone().unwrap(),]
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// The cross process lock has been correctly updated, and all the hashes are
|
||||
// empty after a logout.
|
||||
let xp_manager =
|
||||
oidc.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
|
||||
let guard = xp_manager.spin_lock().await?;
|
||||
assert!(guard.db_hash.is_none());
|
||||
assert!(guard.hash_guard.is_none());
|
||||
assert!(!guard.hash_mismatch);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,28 +173,17 @@ use std::{
|
||||
};
|
||||
|
||||
use as_variant::as_variant;
|
||||
use chrono::Utc;
|
||||
use eyeball::SharedObservable;
|
||||
use futures_core::Stream;
|
||||
pub use mas_oidc_client::{error, types};
|
||||
use mas_oidc_client::{
|
||||
http_service::HttpService,
|
||||
jose::jwk::PublicJsonWebKeySet,
|
||||
requests::{
|
||||
authorization_code::{access_token_with_authorization_code, AuthorizationValidationData},
|
||||
discovery::discover,
|
||||
jose::{fetch_jwks, JwtVerificationData},
|
||||
refresh_token::refresh_access_token,
|
||||
registration::register_client,
|
||||
revocation::revoke_token,
|
||||
},
|
||||
requests::authorization_code::AuthorizationValidationData,
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
errors::ClientError,
|
||||
iana::oauth::OAuthTokenTypeHint,
|
||||
oidc::VerifiedProviderMetadata,
|
||||
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
|
||||
requests::AccessTokenResponse,
|
||||
scope::{MatrixApiScopeToken, Scope, ScopeToken},
|
||||
IdToken,
|
||||
},
|
||||
@@ -209,17 +198,23 @@ use tracing::{error, trace, warn};
|
||||
use url::Url;
|
||||
|
||||
mod auth_code_builder;
|
||||
mod backend;
|
||||
mod cross_process;
|
||||
mod data_serde;
|
||||
mod end_session_builder;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use self::cross_process::{
|
||||
CrossProcessRefreshLockError, CrossProcessRefreshLockGuard, CrossProcessRefreshManager,
|
||||
};
|
||||
pub use self::{
|
||||
auth_code_builder::{OidcAuthCodeUrlBuilder, OidcAuthorizationData},
|
||||
end_session_builder::{OidcEndSessionData, OidcEndSessionUrlBuilder},
|
||||
};
|
||||
use self::{
|
||||
backend::{server::OidcServer, OidcBackend},
|
||||
cross_process::{
|
||||
CrossProcessRefreshLockError, CrossProcessRefreshLockGuard, CrossProcessRefreshManager,
|
||||
},
|
||||
};
|
||||
use crate::{authentication::AuthData, client::SessionChange, Client, RefreshTokenError, Result};
|
||||
|
||||
pub(crate) struct OidcCtx {
|
||||
@@ -268,11 +263,14 @@ impl fmt::Debug for OidcAuthData {
|
||||
pub struct Oidc {
|
||||
/// The underlying Matrix API client.
|
||||
client: Client,
|
||||
|
||||
/// The implementation of the OIDC backend.
|
||||
backend: Arc<dyn OidcBackend>,
|
||||
}
|
||||
|
||||
impl Oidc {
|
||||
pub(crate) fn new(client: Client) -> Self {
|
||||
Self { client }
|
||||
Self { client: client.clone(), backend: Arc::new(OidcServer::new(client)) }
|
||||
}
|
||||
|
||||
fn ctx(&self) -> &OidcCtx {
|
||||
@@ -338,11 +336,6 @@ impl Oidc {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the `HttpService` to make requests with.
|
||||
fn http_service(&self) -> HttpService {
|
||||
HttpService::new(self.client.inner.http_client.clone())
|
||||
}
|
||||
|
||||
/// The OpenID Connect authentication data.
|
||||
///
|
||||
/// Returns `None` if the client registration was not restored with
|
||||
@@ -427,7 +420,7 @@ impl Oidc {
|
||||
&self,
|
||||
issuer: &str,
|
||||
) -> Result<VerifiedProviderMetadata, OidcError> {
|
||||
discover(&self.http_service(), issuer).await.map_err(Into::into)
|
||||
self.backend.discover(issuer).await
|
||||
}
|
||||
|
||||
/// Fetch the OpenID Connect metadata of the issuer.
|
||||
@@ -440,14 +433,6 @@ impl Oidc {
|
||||
self.given_provider_metadata(issuer).await
|
||||
}
|
||||
|
||||
/// Fetch the OpenID Connect JSON Web Key Set at the given URI.
|
||||
///
|
||||
/// Returns an error if the client registration was not restored, or if an
|
||||
/// error occurred when fetching the data.
|
||||
async fn fetch_jwks(&self, jwks_uri: &Url) -> Result<PublicJsonWebKeySet, OidcError> {
|
||||
Ok(fetch_jwks(&self.http_service(), jwks_uri).await?)
|
||||
}
|
||||
|
||||
/// The OpenID Connect metadata of this client used during registration.
|
||||
///
|
||||
/// Returns `None` if the client registration was not restored with
|
||||
@@ -654,20 +639,16 @@ impl Oidc {
|
||||
client_metadata: VerifiedClientMetadata,
|
||||
software_statement: Option<String>,
|
||||
) -> Result<ClientRegistrationResponse, OidcError> {
|
||||
let provider_metadata = self.given_provider_metadata(issuer).await?;
|
||||
let provider_metadata = self.backend.discover(issuer).await?;
|
||||
|
||||
let registration_endpoint = provider_metadata
|
||||
.registration_endpoint
|
||||
.as_ref()
|
||||
.ok_or(OidcError::NoRegistrationSupport)?;
|
||||
|
||||
register_client(
|
||||
&self.http_service(),
|
||||
registration_endpoint,
|
||||
client_metadata,
|
||||
software_statement,
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
self.backend
|
||||
.register_client(registration_endpoint, client_metadata, software_statement)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Set the data of a client that is registered with an OpenID Connect
|
||||
@@ -1007,44 +988,32 @@ impl Oidc {
|
||||
/// Returns an error if a request fails.
|
||||
pub async fn finish_authorization(
|
||||
&self,
|
||||
code: AuthorizationCode,
|
||||
) -> Result<AccessTokenResponse, OidcError> {
|
||||
let provider_metadata = self.provider_metadata().await?;
|
||||
auth_code: AuthorizationCode,
|
||||
) -> Result<(), OidcError> {
|
||||
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
|
||||
let validation_data = data
|
||||
.authorization_data
|
||||
.lock()
|
||||
.await
|
||||
.remove(&code.state)
|
||||
.remove(&auth_code.state)
|
||||
.ok_or(OidcError::InvalidState)?;
|
||||
|
||||
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
|
||||
let id_token_verification_data = JwtVerificationData {
|
||||
issuer: provider_metadata.issuer(),
|
||||
jwks: &jwks,
|
||||
client_id: &data.credentials.client_id().to_owned(),
|
||||
signing_algorithm: data.metadata.id_token_signed_response_alg(),
|
||||
};
|
||||
let provider_metadata = self.provider_metadata().await?;
|
||||
|
||||
let (response, id_token) = access_token_with_authorization_code(
|
||||
&self.http_service(),
|
||||
data.credentials.clone(),
|
||||
provider_metadata.token_endpoint(),
|
||||
code.code,
|
||||
validation_data,
|
||||
Some(id_token_verification_data),
|
||||
Utc::now(),
|
||||
&mut rng()?,
|
||||
)
|
||||
.await?;
|
||||
let session_tokens = self
|
||||
.backend
|
||||
.trade_authorization_code_for_tokens(
|
||||
provider_metadata,
|
||||
data.credentials.clone(),
|
||||
data.metadata.clone(),
|
||||
auth_code,
|
||||
validation_data,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.set_session_tokens(OidcSessionTokens {
|
||||
access_token: response.access_token.clone(),
|
||||
refresh_token: response.refresh_token.clone(),
|
||||
latest_id_token: id_token,
|
||||
});
|
||||
self.set_session_tokens(session_tokens);
|
||||
|
||||
Ok(response)
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Abort the authorization process.
|
||||
@@ -1074,52 +1043,44 @@ impl Oidc {
|
||||
latest_id_token: Option<IdToken<'static>>,
|
||||
lock: Option<CrossProcessRefreshLockGuard>,
|
||||
) -> Result<(), OidcError> {
|
||||
let provider_metadata = self.provider_metadata().await?;
|
||||
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
|
||||
|
||||
let jwks = self.fetch_jwks(provider_metadata.jwks_uri()).await?;
|
||||
|
||||
trace!("Token refresh: attempting to refresh with refresh_token {}", hash(&refresh_token));
|
||||
|
||||
// Do not interrupt refresh access token requests and processing, by detaching
|
||||
// the request sending and response processing.
|
||||
|
||||
let provider_metadata = self.provider_metadata().await?;
|
||||
|
||||
let this = self.clone();
|
||||
let data = self.data().ok_or(OidcError::NotAuthenticated)?;
|
||||
let credentials = data.credentials.clone();
|
||||
let signing_algorithm = data.metadata.id_token_signed_response_alg().clone();
|
||||
let metadata = data.metadata.clone();
|
||||
|
||||
spawn(async move {
|
||||
let id_token_verification_data = JwtVerificationData {
|
||||
issuer: provider_metadata.issuer(),
|
||||
jwks: &jwks,
|
||||
client_id: &credentials.client_id().to_owned(),
|
||||
signing_algorithm: &signing_algorithm,
|
||||
};
|
||||
tracing::trace!(
|
||||
"Token refresh: attempting to refresh with refresh_token {}",
|
||||
hash(&refresh_token)
|
||||
);
|
||||
|
||||
match refresh_access_token(
|
||||
&this.http_service(),
|
||||
credentials,
|
||||
provider_metadata.token_endpoint(),
|
||||
refresh_token.clone(),
|
||||
None,
|
||||
Some(id_token_verification_data),
|
||||
latest_id_token.as_ref(),
|
||||
Utc::now(),
|
||||
&mut rng()?,
|
||||
)
|
||||
.await
|
||||
.map_err(OidcError::from)
|
||||
match this
|
||||
.backend
|
||||
.refresh_access_token(
|
||||
provider_metadata,
|
||||
credentials,
|
||||
&metadata,
|
||||
refresh_token.clone(),
|
||||
latest_id_token.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(OidcError::from)
|
||||
{
|
||||
Ok((response, _id_token)) => {
|
||||
Ok(new_tokens) => {
|
||||
trace!(
|
||||
"Token refresh: new refresh_token: {:?} / access_token: {}",
|
||||
response.refresh_token.as_ref().map(hash),
|
||||
hash(&response.access_token)
|
||||
new_tokens.refresh_token.as_ref().map(hash),
|
||||
hash(&new_tokens.access_token)
|
||||
);
|
||||
|
||||
let tokens = OidcSessionTokens {
|
||||
access_token: response.access_token,
|
||||
refresh_token: response.refresh_token.clone().or(Some(refresh_token)),
|
||||
access_token: new_tokens.access_token,
|
||||
refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)),
|
||||
latest_id_token,
|
||||
};
|
||||
|
||||
@@ -1260,32 +1221,27 @@ impl Oidc {
|
||||
provider_metadata.revocation_endpoint.as_ref().ok_or(OidcError::NoRevocationSupport)?;
|
||||
|
||||
let tokens = self.session_tokens().ok_or(OidcError::NotAuthenticated)?;
|
||||
let mut rng = rng()?;
|
||||
|
||||
// Revoke the access token.
|
||||
revoke_token(
|
||||
&self.http_service(),
|
||||
client_credentials.clone(),
|
||||
revocation_endpoint,
|
||||
tokens.access_token,
|
||||
Some(OAuthTokenTypeHint::AccessToken),
|
||||
Utc::now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
self.backend
|
||||
.revoke_token(
|
||||
client_credentials.clone(),
|
||||
revocation_endpoint,
|
||||
tokens.access_token,
|
||||
Some(OAuthTokenTypeHint::AccessToken),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Revoke the refresh token, if any.
|
||||
if let Some(refresh_token) = tokens.refresh_token {
|
||||
revoke_token(
|
||||
&self.http_service(),
|
||||
client_credentials.clone(),
|
||||
revocation_endpoint,
|
||||
refresh_token,
|
||||
Some(OAuthTokenTypeHint::RefreshToken),
|
||||
Utc::now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
self.backend
|
||||
.revoke_token(
|
||||
client_credentials.clone(),
|
||||
revocation_endpoint,
|
||||
refresh_token,
|
||||
Some(OAuthTokenTypeHint::RefreshToken),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let end_session_builder =
|
||||
@@ -1522,105 +1478,3 @@ fn hash<T: Hash>(x: &T) -> u64 {
|
||||
x.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use mas_oidc_client::types::{
|
||||
client_credentials::ClientCredentials, registration::ClientMetadata,
|
||||
};
|
||||
use matrix_sdk_base::SessionMeta;
|
||||
use matrix_sdk_test::async_test;
|
||||
use ruma::{
|
||||
api::client::discovery::discover_homeserver::AuthenticationServerInfo, OwnedUserId,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
oidc::{OidcAccountManagementAction, OidcSession, OidcSessionTokens, UserSession},
|
||||
ClientBuilder,
|
||||
};
|
||||
|
||||
#[async_test]
|
||||
async fn test_account_management_url() {
|
||||
let builder =
|
||||
ClientBuilder::new().homeserver_url(Url::parse("https://example.com").unwrap());
|
||||
let client = builder.build().await.unwrap();
|
||||
|
||||
client
|
||||
.restore_session(OidcSession {
|
||||
credentials: ClientCredentials::None { client_id: "client_id".to_owned() },
|
||||
metadata: ClientMetadata {
|
||||
redirect_uris: Some(vec![Url::parse("https://example.com/login").unwrap()]),
|
||||
..Default::default()
|
||||
}
|
||||
.validate()
|
||||
.unwrap(),
|
||||
user: UserSession {
|
||||
meta: SessionMeta {
|
||||
user_id: OwnedUserId::from_str("@user:example.com").unwrap(),
|
||||
device_id: "device_id".into(),
|
||||
},
|
||||
tokens: OidcSessionTokens {
|
||||
access_token: "access_token".to_owned(),
|
||||
refresh_token: None,
|
||||
latest_id_token: None,
|
||||
},
|
||||
issuer_info: AuthenticationServerInfo::new(
|
||||
"https://example.com".to_owned(),
|
||||
Some("https://example.com/account".to_owned()),
|
||||
),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
client.oidc().account_management_url(None).unwrap(),
|
||||
Some(Url::parse("https://example.com/account").unwrap())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::Profile))
|
||||
.unwrap(),
|
||||
Some(Url::parse("https://example.com/account?action=profile").unwrap())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::SessionsList))
|
||||
.unwrap(),
|
||||
Some(Url::parse("https://example.com/account?action=sessions_list").unwrap())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::SessionView {
|
||||
device_id: "my_phone".into()
|
||||
}))
|
||||
.unwrap(),
|
||||
Some(
|
||||
Url::parse("https://example.com/account?action=session_view&device_id=my_phone")
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::SessionEnd {
|
||||
device_id: "my_old_phone".into()
|
||||
}))
|
||||
.unwrap(),
|
||||
Some(
|
||||
Url::parse("https://example.com/account?action=session_end&device_id=my_old_phone")
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
299
crates/matrix-sdk/src/oidc/tests.rs
Normal file
299
crates/matrix-sdk/src/oidc/tests.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use assert_matches::assert_matches;
|
||||
use mas_oidc_client::{
|
||||
requests::authorization_code::AuthorizationValidationData,
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
errors::ClientErrorCode,
|
||||
iana::oauth::OAuthClientAuthenticationMethod,
|
||||
registration::{ClientMetadata, VerifiedClientMetadata},
|
||||
},
|
||||
};
|
||||
use matrix_sdk_base::SessionMeta;
|
||||
#[cfg(test)]
|
||||
use matrix_sdk_test::async_test;
|
||||
use ruma::{api::client::discovery::discover_homeserver::AuthenticationServerInfo, owned_user_id};
|
||||
use url::Url;
|
||||
|
||||
use super::{
|
||||
backend::mock::{MockImpl, AUTHORIZATION_URL, ISSUER_URL},
|
||||
AuthorizationCode, AuthorizationError, AuthorizationResponse, Oidc,
|
||||
OidcAccountManagementAction, OidcError, OidcSession, OidcSessionTokens,
|
||||
RedirectUriQueryParseError, UserSession,
|
||||
};
|
||||
use crate::test_utils::test_client_builder;
|
||||
|
||||
const CLIENT_ID: &str = "test_client_id";
|
||||
|
||||
pub fn mock_registered_client_data() -> (ClientCredentials, VerifiedClientMetadata) {
|
||||
(
|
||||
ClientCredentials::None { client_id: CLIENT_ID.to_owned() },
|
||||
ClientMetadata {
|
||||
redirect_uris: Some(vec![]), // empty vector is ok lol
|
||||
token_endpoint_auth_method: Some(OAuthClientAuthenticationMethod::None),
|
||||
..ClientMetadata::default()
|
||||
}
|
||||
.validate()
|
||||
.expect("validate client metadata"),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mock_session(tokens: OidcSessionTokens) -> OidcSession {
|
||||
let (credentials, metadata) = mock_registered_client_data();
|
||||
OidcSession {
|
||||
credentials,
|
||||
metadata,
|
||||
user: UserSession {
|
||||
meta: SessionMeta {
|
||||
user_id: ruma::user_id!("@u:e.uk").to_owned(),
|
||||
device_id: ruma::device_id!("XYZ").to_owned(),
|
||||
},
|
||||
tokens,
|
||||
issuer_info: AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_account_management_url() {
|
||||
let builder = test_client_builder(Some("https://example.com".to_owned()));
|
||||
let client = builder.build().await.unwrap();
|
||||
|
||||
client
|
||||
.restore_session(OidcSession {
|
||||
credentials: ClientCredentials::None { client_id: "client_id".to_owned() },
|
||||
|
||||
metadata: ClientMetadata {
|
||||
redirect_uris: Some(vec![Url::parse("https://example.com/login").unwrap()]),
|
||||
..Default::default()
|
||||
}
|
||||
.validate()
|
||||
.unwrap(),
|
||||
|
||||
user: UserSession {
|
||||
meta: SessionMeta {
|
||||
user_id: owned_user_id!("@user:example.com"),
|
||||
device_id: "device_id".into(),
|
||||
},
|
||||
tokens: OidcSessionTokens {
|
||||
access_token: "access_token".to_owned(),
|
||||
refresh_token: None,
|
||||
latest_id_token: None,
|
||||
},
|
||||
issuer_info: AuthenticationServerInfo::new(
|
||||
"https://example.com".to_owned(),
|
||||
Some("https://example.com/account".to_owned()),
|
||||
),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
client.oidc().account_management_url(None).unwrap(),
|
||||
Some(Url::parse("https://example.com/account").unwrap())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client.oidc().account_management_url(Some(OidcAccountManagementAction::Profile)).unwrap(),
|
||||
Some(Url::parse("https://example.com/account?action=profile").unwrap())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::SessionsList))
|
||||
.unwrap(),
|
||||
Some(Url::parse("https://example.com/account?action=sessions_list").unwrap())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::SessionView {
|
||||
device_id: "my_phone".into()
|
||||
}))
|
||||
.unwrap(),
|
||||
Some(
|
||||
Url::parse("https://example.com/account?action=session_view&device_id=my_phone")
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
client
|
||||
.oidc()
|
||||
.account_management_url(Some(OidcAccountManagementAction::SessionEnd {
|
||||
device_id: "my_old_phone".into()
|
||||
}))
|
||||
.unwrap(),
|
||||
Some(
|
||||
Url::parse("https://example.com/account?action=session_end&device_id=my_old_phone")
|
||||
.unwrap()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_login() -> anyhow::Result<()> {
|
||||
let client = test_client_builder(None).build().await?;
|
||||
|
||||
let device_id = "D3V1C31D".to_owned(); // yo this is 1999 speaking
|
||||
|
||||
let oidc = Oidc { client: client.clone(), backend: Arc::new(MockImpl::new()) };
|
||||
|
||||
let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None);
|
||||
let (client_credentials, client_metadata) = mock_registered_client_data();
|
||||
oidc.restore_registered_client(issuer_info, client_metadata, client_credentials);
|
||||
|
||||
let redirect_uri_str = "http://matrix.example.com/oidc/callback";
|
||||
let redirect_uri = Url::parse(redirect_uri_str)?;
|
||||
let mut authorization_data = oidc.login(redirect_uri, Some(device_id.clone()))?.build().await?;
|
||||
|
||||
tracing::debug!("authorization data URL = {}", authorization_data.url);
|
||||
|
||||
let mut num_expected = 6;
|
||||
let mut nonce = None;
|
||||
|
||||
for (key, val) in authorization_data.url.query_pairs() {
|
||||
match &*key {
|
||||
"response_type" => {
|
||||
assert_eq!(val, "code");
|
||||
num_expected -= 1;
|
||||
}
|
||||
"client_id" => {
|
||||
assert_eq!(val, CLIENT_ID);
|
||||
num_expected -= 1;
|
||||
}
|
||||
"redirect_uri" => {
|
||||
assert_eq!(val, redirect_uri_str);
|
||||
num_expected -= 1;
|
||||
}
|
||||
"scope" => {
|
||||
assert_eq!(val, format!("openid urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id}"));
|
||||
num_expected -= 1;
|
||||
}
|
||||
"state" => {
|
||||
num_expected -= 1;
|
||||
assert_eq!(val, authorization_data.state);
|
||||
}
|
||||
"nonce" => {
|
||||
num_expected -= 1;
|
||||
nonce = Some(val);
|
||||
}
|
||||
_ => panic!("unexpected query parameter: {key}={val}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(num_expected, 0);
|
||||
|
||||
let data = oidc.data().unwrap();
|
||||
let authorization_data_guard = data.authorization_data.lock().await;
|
||||
|
||||
let state = authorization_data_guard.get(&authorization_data.state).context("missing state")?;
|
||||
let nonce = nonce.context("missing nonce")?;
|
||||
assert_eq!(nonce, state.nonce);
|
||||
|
||||
authorization_data.url.set_query(None);
|
||||
assert_eq!(authorization_data.url, Url::parse(AUTHORIZATION_URL).unwrap(),);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_authorization_response() -> anyhow::Result<()> {
|
||||
let uri = Url::parse("https://example.com")?;
|
||||
assert_matches!(
|
||||
AuthorizationResponse::parse_uri(&uri),
|
||||
Err(RedirectUriQueryParseError::MissingQuery)
|
||||
);
|
||||
|
||||
let uri = Url::parse("https://example.com?code=123&state=456")?;
|
||||
assert_matches!(
|
||||
AuthorizationResponse::parse_uri(&uri),
|
||||
Ok(AuthorizationResponse::Success(AuthorizationCode { code, state })) => {
|
||||
assert_eq!(code, "123");
|
||||
assert_eq!(state, "456");
|
||||
}
|
||||
);
|
||||
|
||||
let uri = Url::parse("https://example.com?error=invalid_grant&state=456")?;
|
||||
assert_matches!(
|
||||
AuthorizationResponse::parse_uri(&uri),
|
||||
Ok(AuthorizationResponse::Error(AuthorizationError { error, state })) => {
|
||||
assert_eq!(error.error, ClientErrorCode::InvalidGrant);
|
||||
assert_eq!(error.error_description, None);
|
||||
assert_eq!(state, "456");
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_finish_authorization() -> anyhow::Result<()> {
|
||||
let client = test_client_builder(None).build().await?;
|
||||
|
||||
let session_tokens = OidcSessionTokens {
|
||||
access_token: "4cc3ss".to_owned(),
|
||||
refresh_token: Some("r3fr3$h".to_owned()),
|
||||
latest_id_token: None,
|
||||
};
|
||||
let oidc = Oidc {
|
||||
client: client.clone(),
|
||||
backend: Arc::new(MockImpl::new().next_session_tokens(session_tokens.clone())),
|
||||
};
|
||||
|
||||
let issuer_info = AuthenticationServerInfo::new(ISSUER_URL.to_owned(), None);
|
||||
let (client_credentials, client_metadata) = mock_registered_client_data();
|
||||
oidc.restore_registered_client(issuer_info, client_metadata, client_credentials);
|
||||
|
||||
// If the state is missing, then any attempt to finish authorizing will fail.
|
||||
let res = oidc
|
||||
.finish_authorization(AuthorizationCode { code: "42".to_owned(), state: "none".to_owned() })
|
||||
.await;
|
||||
|
||||
assert_matches!(res, Err(OidcError::InvalidState));
|
||||
assert!(oidc.session_tokens().is_none());
|
||||
|
||||
// Assuming a non-empty state "123"...
|
||||
let state = "state".to_owned();
|
||||
let redirect_uri = "http://matrix.example.com/oidc/callback";
|
||||
let auth_validation_data = AuthorizationValidationData {
|
||||
state: state.clone(),
|
||||
nonce: "nonce".to_owned(),
|
||||
redirect_uri: Url::parse(redirect_uri)?,
|
||||
code_challenge_verifier: None,
|
||||
};
|
||||
|
||||
{
|
||||
let data = oidc.data().context("missing data")?;
|
||||
let prev = data.authorization_data.lock().await.insert(state.clone(), {
|
||||
AuthorizationValidationData { ..auth_validation_data.clone() }
|
||||
});
|
||||
assert!(prev.is_none());
|
||||
}
|
||||
|
||||
// Finishing the authorization for another state won't work.
|
||||
let res = oidc
|
||||
.finish_authorization(AuthorizationCode {
|
||||
code: "1337".to_owned(),
|
||||
state: "none".to_owned(),
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_matches!(res, Err(OidcError::InvalidState));
|
||||
assert!(oidc.session_tokens().is_none());
|
||||
assert!(oidc.data().unwrap().authorization_data.lock().await.get(&state).is_some());
|
||||
|
||||
// Finishing the authorization for the expected state will work.
|
||||
oidc.finish_authorization(AuthorizationCode { code: "1337".to_owned(), state: state.clone() })
|
||||
.await?;
|
||||
|
||||
assert_eq!(oidc.session_tokens(), Some(session_tokens));
|
||||
assert!(oidc.data().unwrap().authorization_data.lock().await.get(&state).is_none());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user