refactor(oauth): Remove OAuthRegistrationStore

MSC2966 was updated, clients should re-register for every log in, so we
don't need to store the client IDs between logins.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-04-02 13:18:42 +02:00
committed by Stefan Ceriu
parent c4d9ec98c3
commit 8883e081af
11 changed files with 134 additions and 613 deletions

View File

@@ -1,15 +1,14 @@
use std::{
collections::HashMap,
fmt::{self, Debug},
path::PathBuf,
sync::Arc,
};
use matrix_sdk::{
authentication::oauth::{
error::{OAuthAuthorizationCodeError, OAuthRegistrationStoreError},
error::OAuthAuthorizationCodeError,
registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
ClientId, OAuthError as SdkOAuthError, OAuthRegistrationStore,
ClientId, ClientRegistrationData, OAuthError as SdkOAuthError,
},
Error,
};
@@ -123,14 +122,12 @@ pub struct OidcConfiguration {
/// An array of e-mail addresses of people responsible for this client.
pub contacts: Option<Vec<String>>,
/// Pre-configured registrations for use with issuers that don't support
/// Pre-configured registrations for use with homeservers that don't support
/// dynamic client registration.
pub static_registrations: HashMap<String, String>,
/// A file path where any dynamic registrations should be stored.
///
/// Suggested value: `{base_path}/oidc/registrations.json`
pub dynamic_registrations_file: String,
/// The keys of the map should be the URLs of the homeservers, but keys
/// using `issuer` URLs are also supported.
pub static_registrations: HashMap<String, String>,
}
impl OidcConfiguration {
@@ -165,12 +162,10 @@ impl OidcConfiguration {
Raw::new(&metadata).map_err(|_| OidcError::MetadataInvalid)
}
pub async fn registrations(&self) -> Result<OAuthRegistrationStore, OidcError> {
pub(crate) fn registration_data(&self) -> Result<ClientRegistrationData, OidcError> {
let client_metadata = self.client_metadata()?;
let registrations_file = PathBuf::from(&self.dynamic_registrations_file);
let mut registrations =
OAuthRegistrationStore::new(registrations_file, client_metadata).await?;
let mut registration_data = ClientRegistrationData::new(client_metadata);
if !self.static_registrations.is_empty() {
let static_registrations = self
@@ -185,10 +180,10 @@ impl OidcConfiguration {
})
.collect();
registrations = registrations.with_static_registrations(static_registrations);
registration_data.static_registrations = Some(static_registrations);
}
Ok(registrations)
Ok(registration_data)
}
}
@@ -201,8 +196,6 @@ pub enum OidcError {
NotSupported,
#[error("Unable to use OIDC as the supplied client metadata is invalid.")]
MetadataInvalid,
#[error("Failed to use the supplied registrations file path.")]
RegistrationsPathInvalid,
#[error("The supplied callback URL used to complete OIDC is invalid.")]
CallbackUrlInvalid,
#[error("The OIDC login was cancelled by the user.")]
@@ -228,17 +221,6 @@ impl From<SdkOAuthError> for OidcError {
}
}
impl From<OAuthRegistrationStoreError> for OidcError {
fn from(e: OAuthRegistrationStoreError) -> OidcError {
match e {
OAuthRegistrationStoreError::NotAFilePath | OAuthRegistrationStoreError::File(_) => {
OidcError::RegistrationsPathInvalid
}
_ => OidcError::Generic { message: e.to_string() },
}
}
}
impl From<Error> for OidcError {
fn from(e: Error) -> OidcError {
match e {

View File

@@ -408,10 +408,10 @@ impl Client {
oidc_configuration: &OidcConfiguration,
prompt: Option<OidcPrompt>,
) -> Result<Arc<OAuthAuthorizationData>, OidcError> {
let registrations = oidc_configuration.registrations().await?;
let registration_data = oidc_configuration.registration_data()?;
let redirect_uri = oidc_configuration.redirect_uri()?;
let mut url_builder = self.inner.oauth().login(registrations.into(), redirect_uri, None);
let mut url_builder = self.inner.oauth().login(redirect_uri, None, Some(registration_data));
if let Some(prompt) = prompt {
url_builder = url_builder.prompt(vec![prompt.into()]);

View File

@@ -755,13 +755,12 @@ impl ClientBuilder {
}
})?;
let registrations = oidc_configuration
.registrations()
.await
let registration_data = oidc_configuration
.registration_data()
.map_err(|_| HumanQrLoginError::OidcMetadataInvalid)?;
let oauth = client.inner.oauth();
let login = oauth.login_with_qr_code(&qr_code_data.inner, registrations.into());
let login = oauth.login_with_qr_code(&qr_code_data.inner, Some(&registration_data));
let mut progress = login.subscribe_to_progress();

View File

@@ -24,7 +24,7 @@ use ruma::{
use tracing::{info, instrument};
use url::Url;
use super::{ClientRegistrationMethod, OAuth, OAuthError};
use super::{ClientRegistrationData, OAuth, OAuthError};
use crate::{authentication::oauth::AuthorizationValidationData, Result};
/// Builder type used to configure optional settings for authorization with an
@@ -34,7 +34,7 @@ use crate::{authentication::oauth::AuthorizationValidationData, Result};
#[allow(missing_debug_implementations)]
pub struct OAuthAuthCodeUrlBuilder {
oauth: OAuth,
registration_method: ClientRegistrationMethod,
registration_data: Option<ClientRegistrationData>,
scopes: Vec<Scope>,
device_id: OwnedDeviceId,
redirect_uri: Url,
@@ -45,14 +45,14 @@ pub struct OAuthAuthCodeUrlBuilder {
impl OAuthAuthCodeUrlBuilder {
pub(super) fn new(
oauth: OAuth,
registration_method: ClientRegistrationMethod,
scopes: Vec<Scope>,
device_id: OwnedDeviceId,
redirect_uri: Url,
registration_data: Option<ClientRegistrationData>,
) -> Self {
Self {
oauth,
registration_method,
registration_data,
scopes,
device_id,
redirect_uri,
@@ -93,19 +93,12 @@ impl OAuthAuthCodeUrlBuilder {
/// request fails.
#[instrument(target = "matrix_sdk::client", skip_all)]
pub async fn build(self) -> Result<OAuthAuthorizationData, OAuthError> {
let Self {
oauth,
registration_method,
scopes,
device_id,
redirect_uri,
prompt,
login_hint,
} = self;
let Self { oauth, registration_data, scopes, device_id, redirect_uri, prompt, login_hint } =
self;
let server_metadata = oauth.server_metadata().await?;
oauth.use_registration_method(&server_metadata, &registration_method).await?;
oauth.use_registration_data(&server_metadata, registration_data.as_ref()).await?;
let data = oauth.data().expect("OAuth 2.0 data should be set after registration");
info!(

View File

@@ -31,8 +31,6 @@ use ruma::{
#[cfg(feature = "e2e-encryption")]
pub use super::cross_process::CrossProcessRefreshLockError;
#[cfg(not(target_arch = "wasm32"))]
pub use super::registration_store::OAuthRegistrationStoreError;
/// An error when interacting with the OAuth 2.0 authorization server.
pub type OAuthRequestError<T> =
@@ -258,11 +256,6 @@ pub enum OAuthClientRegistrationError {
/// Deserialization of the registration response failed.
#[error("failed to deserialize registration response: {0}")]
FromJson(serde_json::Error),
/// Failed to access or store the registration in the store.
#[cfg(not(target_arch = "wasm32"))]
#[error("failed to use registration store: {0}")]
Store(#[from] OAuthRegistrationStoreError),
}
/// Error response returned by server after requesting an authorization code.

View File

@@ -176,8 +176,6 @@ mod oidc_discovery;
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
pub mod qrcode;
pub mod registration;
#[cfg(not(target_arch = "wasm32"))]
mod registration_store;
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests;
@@ -185,8 +183,6 @@ mod tests;
use self::cross_process::{CrossProcessRefreshLockGuard, CrossProcessRefreshManager};
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
use self::qrcode::LoginWithQrCode;
#[cfg(not(target_arch = "wasm32"))]
pub use self::registration_store::OAuthRegistrationStore;
pub use self::{
account_management_url::{AccountManagementActionFull, AccountManagementUrlBuilder},
auth_code_builder::{OAuthAuthCodeUrlBuilder, OAuthAuthorizationData},
@@ -351,6 +347,15 @@ impl OAuth {
/// the private cross-signing keys and the backup key from the existing
/// device to the new device.
///
/// # Arguments
///
/// * `data` - The data scanned from a QR code.
///
/// * `registration_data` - The data to restore or register the client with
/// the server. If this is not provided, an error will occur unless
/// [`OAuth::register_client()`] or [`OAuth::restore_registered_client()`]
/// was called previously.
///
/// # Example
///
/// ```no_run
@@ -384,11 +389,12 @@ impl OAuth {
/// .await?;
///
/// let oauth = client.oauth();
/// let metadata: Raw<ClientMetadata> = client_metadata();
/// let client_metadata: Raw<ClientMetadata> = client_metadata();
/// let registration_data = client_metadata.into();
///
/// // Subscribing to the progress is necessary since we need to input the check
/// // code on the existing device.
/// let login = oauth.login_with_qr_code(&qr_code_data, metadata.into());
/// let login = oauth.login_with_qr_code(&qr_code_data, Some(&registration_data));
/// let mut progress = login.subscribe_to_progress();
///
/// // Create a task which will show us the progress and tell us the check
@@ -420,73 +426,44 @@ impl OAuth {
pub fn login_with_qr_code<'a>(
&'a self,
data: &'a QrCodeData,
registration_method: ClientRegistrationMethod,
registration_data: Option<&'a ClientRegistrationData>,
) -> LoginWithQrCode<'a> {
LoginWithQrCode::new(&self.client, registration_method, data)
LoginWithQrCode::new(&self.client, data, registration_data)
}
/// Restore or register the OAuth 2.0 client for the server with the given
/// metadata, with the given [`OAuthRegistrationStore`].
///
/// If there is a client ID in the store, it is used to restore the client.
/// Otherwise, the client is registered with the metadata in the store.
///
/// Returns an error if there is an error while accessing the store, or
/// while registering the client.
#[cfg(not(target_arch = "wasm32"))]
async fn restore_or_register_client(
&self,
server_metadata: &AuthorizationServerMetadata,
registrations: &OAuthRegistrationStore,
) -> std::result::Result<(), OAuthClientRegistrationError> {
if let Some(client_id) = registrations.client_id(&server_metadata.issuer).await? {
self.restore_registered_client(server_metadata.issuer.clone(), client_id);
tracing::info!("OAuth 2.0 configuration loaded from disk.");
return Ok(());
};
tracing::info!("Registering this client for OAuth 2.0.");
let response = self.register_client_inner(server_metadata, &registrations.metadata).await?;
tracing::info!("Persisting OAuth 2.0 registration data.");
registrations
.set_and_write_client_id(response.client_id, server_metadata.issuer.clone())
.await?;
Ok(())
}
/// Restore or register the OAuth 2.0 client for the server with the given
/// metadata, with the given [`ClientRegistrationMethod`].
/// metadata, with the given optional [`ClientRegistrationData`].
///
/// If we already have a client ID, this is a noop.
///
/// Returns an error if there was a problem using the registration method.
async fn use_registration_method(
async fn use_registration_data(
&self,
server_metadata: &AuthorizationServerMetadata,
method: &ClientRegistrationMethod,
data: Option<&ClientRegistrationData>,
) -> std::result::Result<(), OAuthError> {
if self.client_id().is_some() {
tracing::info!("OAuth 2.0 is already configured.");
return Ok(());
};
match method {
ClientRegistrationMethod::None => return Err(OAuthError::NotRegistered),
ClientRegistrationMethod::ClientId(client_id) => {
let Some(data) = data else {
return Err(OAuthError::NotRegistered);
};
if let Some(static_registrations) = &data.static_registrations {
let client_id = static_registrations
.get(&self.client.homeserver())
.or_else(|| static_registrations.get(&server_metadata.issuer));
if let Some(client_id) = client_id {
self.restore_registered_client(server_metadata.issuer.clone(), client_id.clone());
}
ClientRegistrationMethod::Metadata(client_metadata) => {
self.register_client_inner(server_metadata, client_metadata).await?;
}
#[cfg(not(target_arch = "wasm32"))]
ClientRegistrationMethod::Store(registrations) => {
self.restore_or_register_client(server_metadata, registrations).await?
return Ok(());
}
}
self.register_client_inner(server_metadata, &data.metadata).await?;
Ok(())
}
@@ -923,18 +900,12 @@ impl OAuth {
/// Login via OAuth 2.0 with the Authorization Code flow.
///
/// This should be called after [`OAuth::register_client()`] or
/// [`OAuth::restore_registered_client()`].
///
/// [`OAuth::finish_login()`] must be called once the user has been
/// redirected to the `redirect_uri`. [`OAuth::abort_login()`] should be
/// called instead if the authorization should be aborted before completion.
///
/// # Arguments
///
/// * `registration_method` - The method to restore or register the client
/// with the server.
///
/// * `redirect_uri` - The URI where the end user will be redirected after
/// authorizing the login. It must be one of the redirect URIs sent in the
/// client metadata during registration.
@@ -944,16 +915,19 @@ impl OAuth {
/// device ID from a previous login call. Note that this should be done
/// only if the client also holds the corresponding encryption keys.
///
/// * `registration_data` - The data to restore or register the client with
/// the server. If this is not provided, an error will occur unless
/// [`OAuth::register_client()`] or [`OAuth::restore_registered_client()`]
/// was called previously.
///
/// # Example
///
/// ```no_run
/// use anyhow::anyhow;
/// use matrix_sdk::{
/// authentication::oauth::registration::ClientMetadata,
/// ruma::serde::Raw,
/// Client,
/// authentication::oauth::OAuthRegistrationStore,
/// };
/// # use ruma::serde::Raw;
/// # use matrix_sdk::authentication::oauth::registration::ClientMetadata;
/// # let homeserver = unimplemented!();
/// # let redirect_uri = unimplemented!();
/// # let issuer_info = unimplemented!();
@@ -964,13 +938,10 @@ impl OAuth {
/// # _ = async {
/// # let client = Client::new(homeserver).await?;
/// let oauth = client.oauth();
/// let client_metadata: Raw<ClientMetadata> = client_metadata();
/// let registration_data = client_metadata.into();
///
/// let registration_store = OAuthRegistrationStore::new(
/// store_path,
/// client_metadata()
/// ).await?;
///
/// let auth_data = oauth.login(registration_store.into(), redirect_uri, None)
/// let auth_data = oauth.login(redirect_uri, None, Some(registration_data))
/// .build()
/// .await?;
///
@@ -980,7 +951,7 @@ impl OAuth {
/// oauth.finish_login(redirected_to_uri.into()).await?;
///
/// // The session tokens can be persisted from the
/// // `Client::session_tokens()` method.
/// // `OAuth::full_session()` method.
///
/// // You can now make requests to the Matrix API.
/// let _me = client.whoami().await?;
@@ -988,18 +959,18 @@ impl OAuth {
/// ```
pub fn login(
&self,
registration_method: ClientRegistrationMethod,
redirect_uri: Url,
device_id: Option<OwnedDeviceId>,
registration_data: Option<ClientRegistrationData>,
) -> OAuthAuthCodeUrlBuilder {
let (scopes, device_id) = Self::login_scopes(device_id);
OAuthAuthCodeUrlBuilder::new(
self.clone(),
registration_method,
scopes.to_vec(),
device_id,
redirect_uri,
registration_data,
)
}
@@ -1547,47 +1518,32 @@ fn hash_str(x: &str) -> impl fmt::LowerHex {
sha2::Sha256::new().chain_update(x).finalize()
}
/// The available methods to register or restore a client.
#[derive(Debug)]
pub enum ClientRegistrationMethod {
/// No registration will be done.
///
/// This should only be set if [`OAuth::register_client()`] or
/// [`OAuth::restore_registered_client()`] was already called before.
None,
/// Data to register or restore a client.
#[derive(Debug, Clone)]
pub struct ClientRegistrationData {
/// The metadata to use to register the client when using dynamic client
/// registration.
pub metadata: Raw<ClientMetadata>,
/// The given client ID will be used.
/// Static registrations for servers that don't support dynamic registration
/// but provide a client ID out-of-band.
///
/// This will call [`OAuth::restore_registered_client()`] internally.
ClientId(ClientId),
/// The client will register using dynamic client registration, with the
/// given metadata.
///
/// This will call [`OAuth::register_client()`] internally.
Metadata(Raw<ClientMetadata>),
/// Use an [`OAuthRegistrationStore`] to handle registrations.
#[cfg(not(target_arch = "wasm32"))]
Store(OAuthRegistrationStore),
/// The keys of the map should be the URLs of the homeservers, but keys
/// using `issuer` URLs are also supported.
pub static_registrations: Option<HashMap<Url, ClientId>>,
}
impl From<ClientId> for ClientRegistrationMethod {
fn from(value: ClientId) -> Self {
Self::ClientId(value)
impl ClientRegistrationData {
/// Construct a [`ClientRegistrationData`] with the given metadata and no
/// static registrations.
pub fn new(metadata: Raw<ClientMetadata>) -> Self {
Self { metadata, static_registrations: None }
}
}
impl From<Raw<ClientMetadata>> for ClientRegistrationMethod {
impl From<Raw<ClientMetadata>> for ClientRegistrationData {
fn from(value: Raw<ClientMetadata>) -> Self {
Self::Metadata(value)
}
}
#[cfg(not(target_arch = "wasm32"))]
impl From<OAuthRegistrationStore> for ClientRegistrationMethod {
fn from(value: OAuthRegistrationStore) -> Self {
Self::Store(value)
Self::new(value)
}
}

View File

@@ -38,7 +38,7 @@ use super::{
#[cfg(doc)]
use crate::authentication::oauth::OAuth;
use crate::{
authentication::oauth::{ClientRegistrationMethod, OAuthError},
authentication::oauth::{ClientRegistrationData, OAuthError},
Client,
};
@@ -83,7 +83,7 @@ pub enum LoginProgress {
#[derive(Debug)]
pub struct LoginWithQrCode<'a> {
client: &'a Client,
registration_method: ClientRegistrationMethod,
registration_data: Option<&'a ClientRegistrationData>,
qr_code_data: &'a QrCodeData,
state: SharedObservable<LoginProgress>,
}
@@ -270,10 +270,10 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
impl<'a> LoginWithQrCode<'a> {
pub(crate) fn new(
client: &'a Client,
registration_method: ClientRegistrationMethod,
qr_code_data: &'a QrCodeData,
registration_data: Option<&'a ClientRegistrationData>,
) -> LoginWithQrCode<'a> {
LoginWithQrCode { client, registration_method, qr_code_data, state: Default::default() }
LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
}
async fn establish_secure_channel(
@@ -299,7 +299,7 @@ impl<'a> LoginWithQrCode<'a> {
) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
let oauth = self.client.oauth();
let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
oauth.use_registration_method(&server_metadata, &self.registration_method).await?;
oauth.use_registration_data(&server_metadata, self.registration_data).await?;
Ok(server_metadata)
}
@@ -465,7 +465,8 @@ mod test {
let qr_code = alice.qr_code_data().clone();
let oauth = bob.oauth();
let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into());
let registration_data = mock_client_metadata().into();
let login_bob = oauth.login_with_qr_code(&qr_code, Some(&registration_data));
let mut updates = login_bob.subscribe_to_progress();
let updates_task = tokio::spawn(async move {
@@ -552,7 +553,8 @@ mod test {
let qr_code = alice.qr_code_data().clone();
let oauth = bob.oauth();
let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into());
let registration_data = mock_client_metadata().into();
let login_bob = oauth.login_with_qr_code(&qr_code, Some(&registration_data));
let mut updates = login_bob.subscribe_to_progress();
let _updates_task = tokio::spawn(async move {
@@ -675,7 +677,8 @@ mod test {
let qr_code = alice.qr_code_data().clone();
let oauth = bob.oauth();
let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into());
let registration_data = mock_client_metadata().into();
let login_bob = oauth.login_with_qr_code(&qr_code, Some(&registration_data));
let mut updates = login_bob.subscribe_to_progress();
let _updates_task = tokio::spawn(async move {

View File

@@ -1,317 +0,0 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! OAuth 2.0 client registration store.
//!
//! This module provides a way to persist OAuth 2.0 client registrations outside
//! of the state store. This is useful when using a `Client` with an in-memory
//! store or when different store paths are used for multi-account support
//! within the same app, and those accounts need to share the same OAuth 2.0
//! client registration.
use std::{collections::HashMap, io::ErrorKind, path::PathBuf};
use oauth2::ClientId;
use ruma::serde::Raw;
use serde::{Deserialize, Serialize};
use tokio::fs;
use url::Url;
use super::ClientMetadata;
/// Errors that can occur when using the [`OAuthRegistrationStore`].
#[derive(Debug, thiserror::Error)]
pub enum OAuthRegistrationStoreError {
/// The supplied path is not a file path.
#[error("supplied registrations path is not a file path")]
NotAFilePath,
/// An error occurred when reading from or writing to the file.
#[error(transparent)]
File(#[from] std::io::Error),
/// An error occurred when serializing the registration data.
#[error("failed to serialize registration data: {0}")]
IntoJson(serde_json::Error),
/// An error occurred when deserializing the registration data.
#[error("failed to deserialize registration data: {0}")]
FromJson(serde_json::Error),
}
/// An API to store and restore OAuth 2.0 client registrations.
///
/// This stores dynamic client registrations in a file, and accepts "static"
/// client registrations via
/// [`OAuthRegistrationStore::with_static_registrations()`], for servers that
/// don't support dynamic client registration.
///
/// If the client metadata passed to this API changes, the previous
/// registrations that were stored in the file are invalidated, allowing to
/// re-register with the new metadata.
///
/// The purpose of storing client IDs outside of the state store or separate
/// from the user's session is that it allows to reuse the same client ID
/// between user sessions on the same server.
#[derive(Debug)]
pub struct OAuthRegistrationStore {
/// The path of the file where the registrations are stored.
pub(super) file_path: PathBuf,
/// The metadata used to register the client.
/// This is used to check if the client needs to be re-registered.
pub(super) metadata: Raw<ClientMetadata>,
/// Pre-configured registrations for use with issuers that don't support
/// dynamic client registration.
static_registrations: Option<HashMap<Url, ClientId>>,
}
/// The underlying data serialized into the registration file.
#[derive(Debug, Serialize, Deserialize)]
struct FrozenRegistrationData {
/// The metadata used to register the client.
metadata: Raw<ClientMetadata>,
/// All of the registrations this client has made as a HashMap of issuer URL
/// to client ID.
dynamic_registrations: HashMap<Url, ClientId>,
}
impl OAuthRegistrationStore {
/// Creates a new registration store.
///
/// This method creates the `file`'s parent directory if it doesn't exist.
///
/// # Arguments
///
/// * `file` - A file path where the registrations will be stored. This
/// previously took a directory and stored the registrations with the path
/// `supplied_directory/oidc/registrations.json`.
///
/// * `metadata` - The metadata used to register the client. If this changes
/// compared to the value stored in the file, any stored registrations
/// will be invalidated so the client can re-register with the new data.
pub async fn new(
file: PathBuf,
metadata: Raw<ClientMetadata>,
) -> Result<Self, OAuthRegistrationStoreError> {
let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
fs::create_dir_all(parent).await?;
Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
}
/// Add static registrations to the store.
///
/// Static registrations are used for servers that don't support dynamic
/// registration but provide a client ID out-of-band.
///
/// These registrations are not stored in the file and must be provided each
/// time.
pub fn with_static_registrations(
mut self,
static_registrations: HashMap<Url, ClientId>,
) -> Self {
self.static_registrations = Some(static_registrations);
self
}
/// Returns the client ID registered for a particular issuer or `None` if a
/// registration hasn't been made.
///
/// # Arguments
///
/// * `issuer` - The issuer to look up.
///
/// # Errors
///
/// Returns an error if the file could not be read, or if the data in the
/// file could not be deserialized.
pub async fn client_id(
&self,
issuer: &Url,
) -> Result<Option<ClientId>, OAuthRegistrationStoreError> {
if let Some(client_id) =
self.static_registrations.as_ref().and_then(|registrations| registrations.get(issuer))
{
return Ok(Some(client_id.clone()));
}
let data = self.read_registration_data().await?;
Ok(data.and_then(|mut data| data.dynamic_registrations.remove(issuer)))
}
/// Stores a new client ID registration for a particular issuer.
///
/// If a client ID has already been stored for the given issuer, this will
/// overwrite the old value.
///
/// # Arguments
///
/// * `client_id` - The client ID obtained after registration.
///
/// * `issuer` - The issuer associated with the client ID.
///
/// # Errors
///
/// Returns an error if the file could not be read from or written to, or if
/// the data in the file could not be (de)serialized.
pub async fn set_and_write_client_id(
&self,
client_id: ClientId,
issuer: Url,
) -> Result<(), OAuthRegistrationStoreError> {
let mut data = self.read_registration_data().await?.unwrap_or_else(|| {
tracing::info!("Generating new OAuth 2.0 client registration data");
FrozenRegistrationData {
metadata: self.metadata.clone(),
dynamic_registrations: Default::default(),
}
});
data.dynamic_registrations.insert(issuer, client_id);
let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?;
fs::write(&self.file_path, contents).await?;
Ok(())
}
/// The persisted registration data.
///
/// # Errors
///
/// Returns an error if the file could not be read, or if the data in the
/// file could not be deserialized.
async fn read_registration_data(
&self,
) -> Result<Option<FrozenRegistrationData>, OAuthRegistrationStoreError> {
let contents = match fs::read(&self.file_path).await {
Ok(contents) => contents,
Err(error) => {
if error.kind() == ErrorKind::NotFound {
// The file doesn't exist so there is no data.
return Ok(None);
}
// Forward the error.
return Err(error.into());
}
};
let registration_data: FrozenRegistrationData =
serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?;
if registration_data.metadata.json().get() != self.metadata.json().get() {
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
Ok(None)
} else {
Ok(Some(registration_data))
}
}
}
#[cfg(test)]
mod tests {
use matrix_sdk_test::async_test;
use tempfile::tempdir;
use super::*;
use crate::authentication::oauth::registration::{ApplicationType, Localized, OAuthGrantType};
#[async_test]
async fn test_oauth_registration_store() {
// Given a fresh registration store with a single static registration.
let dir = tempdir().unwrap();
let registrations_file = dir.path().join("oauth").join("registrations.json");
let static_url = Url::parse("https://example.com").unwrap();
let static_id = ClientId::new("static_client_id".to_owned());
let dynamic_url = Url::parse("https://example.org").unwrap();
let dynamic_id = ClientId::new("dynamic_client_id".to_owned());
let mut static_registrations = HashMap::new();
static_registrations.insert(static_url.clone(), static_id.clone());
let oidc_metadata = mock_metadata("Example".to_owned());
let registrations = OAuthRegistrationStore::new(registrations_file, oidc_metadata)
.await
.unwrap()
.with_static_registrations(static_registrations);
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone()));
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None);
// When a dynamic registration is added.
registrations
.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
.await
.unwrap();
// Then the dynamic registration should be stored and the static registration
// should be unaffected.
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id));
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));
}
#[async_test]
async fn test_change_of_metadata() {
// Given a single registration with an example app name.
let dir = tempdir().unwrap();
let registrations_file = dir.path().join("oidc").join("registrations.json");
let static_url = Url::parse("https://example.com").unwrap();
let static_id = ClientId::new("static_client_id".to_owned());
let dynamic_url = Url::parse("https://example.org").unwrap();
let dynamic_id = ClientId::new("dynamic_client_id".to_owned());
let oidc_metadata = mock_metadata("Example".to_owned());
let mut static_registrations = HashMap::new();
static_registrations.insert(static_url.clone(), static_id.clone());
let registrations = OAuthRegistrationStore::new(registrations_file.clone(), oidc_metadata)
.await
.unwrap()
.with_static_registrations(static_registrations.clone());
registrations
.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
.await
.unwrap();
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone()));
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));
// When the app name changes.
let new_oidc_metadata = mock_metadata("New App".to_owned());
let registrations = OAuthRegistrationStore::new(registrations_file, new_oidc_metadata)
.await
.unwrap()
.with_static_registrations(static_registrations);
// Then the dynamic registrations are cleared.
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None);
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id));
}
fn mock_metadata(client_name: String) -> Raw<ClientMetadata> {
let callback_url = Url::parse("https://example.org/login/callback").unwrap();
let client_uri = Url::parse("https://example.org/").unwrap();
let mut metadata = ClientMetadata::new(
ApplicationType::Web,
vec![OAuthGrantType::AuthorizationCode { redirect_uris: vec![callback_url] }],
Localized::new(client_uri, None),
);
metadata.client_name = Some(Localized::new(client_name, None));
Raw::new(&metadata).unwrap()
}
}

View File

@@ -9,7 +9,6 @@ use ruma::{
owned_device_id, user_id, DeviceId, ServerName,
};
use serde_json::json;
use tempfile::tempdir;
use tokio::sync::broadcast::error::TryRecvError;
use url::Url;
use wiremock::{
@@ -18,14 +17,13 @@ use wiremock::{
};
use super::{
registration_store::OAuthRegistrationStore, AuthorizationCode, AuthorizationError,
AuthorizationResponse, OAuth, OAuthAuthorizationData, OAuthError, RedirectUriQueryParseError,
UrlOrQuery,
AuthorizationCode, AuthorizationError, AuthorizationResponse, OAuth, OAuthAuthorizationData,
OAuthError, RedirectUriQueryParseError, UrlOrQuery,
};
use crate::{
authentication::oauth::{
error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError},
AuthorizationValidationData, ClientRegistrationMethod, OAuthAuthorizationCodeError,
AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError,
},
test_utils::{
client::{
@@ -40,7 +38,7 @@ use crate::{
const REDIRECT_URI_STRING: &str = "http://127.0.0.1:6778/oauth/callback";
async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAuthRegistrationStore)>
async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, ClientRegistrationData)>
{
let server = MatrixMockServer::new().await;
server.mock_who_am_i().ok().named("whoami").mount().await;
@@ -53,12 +51,7 @@ async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAu
let client = server.client_builder().unlogged().build().await;
let client_metadata = mock_client_metadata();
let registrations_path =
tempdir().unwrap().path().join("matrix-sdk-oauth").join("registrations.json");
let registrations =
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
Ok((client.oauth(), server, mock_redirect_uri(), registrations))
Ok((client.oauth(), server, mock_redirect_uri(), client_metadata.into()))
}
/// Check the URL in the given authorization data.
@@ -154,35 +147,23 @@ async fn check_authorization_url(
#[async_test]
async fn test_high_level_login() -> anyhow::Result<()> {
// Given a fresh environment.
let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap();
let registrations_path = registrations.file_path.clone();
let client_metadata = registrations.metadata.clone();
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
assert!(oauth.issuer().is_none());
assert!(oauth.client_id().is_none());
// When getting the OIDC login URL.
let authorization_data = oauth
.login(registrations.into(), redirect_uri.clone(), None)
.login(redirect_uri.clone(), None, Some(registration_data))
.prompt(vec![Prompt::Create])
.build()
.await
.unwrap();
// Then the client should be configured correctly.
assert_let!(Some(issuer) = oauth.issuer());
assert!(oauth.issuer().is_some());
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// The client ID should have been saved in the registration file.
let registrations =
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
assert_eq!(
registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()),
Some("test_client_id")
);
check_authorization_url(&authorization_data, &oauth, issuer, None, Some("create"), None).await;
// When completing the login with a valid callback.
redirect_uri.set_query(Some(&format!("code=42&state={}", authorization_data.state.secret())));
@@ -195,24 +176,14 @@ async fn test_high_level_login() -> anyhow::Result<()> {
#[async_test]
async fn test_high_level_login_cancellation() -> anyhow::Result<()> {
// Given a client ready to complete login.
let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap();
let registrations_path = registrations.file_path.clone();
let client_metadata = registrations.metadata.clone();
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
let authorization_data =
oauth.login(registrations.into(), redirect_uri.clone(), None).build().await.unwrap();
oauth.login(redirect_uri.clone(), None, Some(registration_data)).build().await.unwrap();
assert_let!(Some(issuer) = oauth.issuer());
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// The client ID should have been saved in the registration file.
let registrations =
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
assert_eq!(
registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()),
Some("test_client_id")
);
check_authorization_url(&authorization_data, &oauth, issuer, None, None, None).await;
// When completing login with a cancellation callback.
@@ -235,24 +206,14 @@ async fn test_high_level_login_cancellation() -> anyhow::Result<()> {
#[async_test]
async fn test_high_level_login_invalid_state() -> anyhow::Result<()> {
// Given a client ready to complete login.
let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap();
let registrations_path = registrations.file_path.clone();
let client_metadata = registrations.metadata.clone();
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
let authorization_data =
oauth.login(registrations.into(), redirect_uri.clone(), None).build().await.unwrap();
oauth.login(redirect_uri.clone(), None, Some(registration_data)).build().await.unwrap();
assert_let!(Some(issuer) = oauth.issuer());
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// The client ID should have been saved in the registration file.
let registrations =
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
assert_eq!(
registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()),
Some("test_client_id")
);
check_authorization_url(&authorization_data, &oauth, issuer, None, None, None).await;
// When completing login with an old/tampered state.
@@ -286,16 +247,14 @@ async fn test_login_url() -> anyhow::Result<()> {
let redirect_uri = Url::parse(redirect_uri_str)?;
// No extra parameters.
let authorization_data = oauth
.login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone()))
.build()
.await?;
let authorization_data =
oauth.login(redirect_uri.clone(), Some(device_id.clone()), None).build().await?;
check_authorization_url(&authorization_data, &oauth, &issuer, Some(&device_id), None, None)
.await;
// With prompt parameter.
let authorization_data = oauth
.login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone()))
.login(redirect_uri.clone(), Some(device_id.clone()), None)
.prompt(vec![Prompt::Create])
.build()
.await?;
@@ -311,7 +270,7 @@ async fn test_login_url() -> anyhow::Result<()> {
// With user_id_hint parameter.
let authorization_data = oauth
.login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone()))
.login(redirect_uri.clone(), Some(device_id.clone()), None)
.user_id_hint(user_id!("@joe:example.org"))
.build()
.await?;
@@ -719,7 +678,7 @@ async fn test_server_metadata() {
}
#[async_test]
async fn test_client_registration_methods() {
async fn test_client_registration_data() {
let server = MatrixMockServer::new().await;
let oauth_server = server.oauth();
let server_metadata = oauth_server.server_metadata();
@@ -727,29 +686,26 @@ async fn test_client_registration_methods() {
// Without registration we get an error.
let client = server.client_builder().unlogged().build().await;
let oauth = client.oauth();
let res =
oauth.use_registration_method(&server_metadata, &ClientRegistrationMethod::None).await;
let res = oauth.use_registration_data(&server_metadata, None).await;
assert_matches!(res, Err(OAuthError::NotRegistered));
assert_eq!(oauth.client_id(), None);
// With a client ID.
oauth
.use_registration_method(
&server_metadata,
&ClientRegistrationMethod::ClientId(mock_client_id()),
)
.await
.unwrap();
// With a static registration.
let registration_data = ClientRegistrationData {
metadata: mock_client_metadata(),
static_registrations: Some([(server_metadata.issuer.clone(), mock_client_id())].into()),
};
oauth.use_registration_data(&server_metadata, Some(&registration_data)).await.unwrap();
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// If we call it again we get the same client ID.
oauth
.use_registration_method(
&server_metadata,
&ClientRegistrationMethod::ClientId(ClientId::new("other_client_id".to_owned())),
)
.await
.unwrap();
// If we call it again, it's a noop.
let registration_data = ClientRegistrationData {
metadata: mock_client_metadata(),
static_registrations: Some(
[(server_metadata.issuer.clone(), ClientId::new("other_client_id".to_owned()))].into(),
),
};
oauth.use_registration_data(&server_metadata, Some(&registration_data)).await.unwrap();
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// With metadata we register a new client ID.
@@ -765,40 +721,6 @@ async fn test_client_registration_methods() {
.mount()
.await;
oauth
.use_registration_method(
&server_metadata,
&ClientRegistrationMethod::Metadata(client_metadata.clone()),
)
.await
.unwrap();
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// With registration store we register a new client ID.
let client = server.client_builder().unlogged().build().await;
let oauth = client.oauth();
let registrations_path =
tempdir().unwrap().path().join("matrix-sdk-oauth").join("registrations.json");
let registrations =
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
let store_method = ClientRegistrationMethod::Store(registrations);
oauth_server
.mock_registration()
.ok()
.mock_once()
.named("registration_with_store")
.mount()
.await;
oauth.use_registration_method(&server_metadata, &store_method).await.unwrap();
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
// With same registration store, the client ID is loaded from the store.
let client = server.client_builder().unlogged().build().await;
let oauth = client.oauth();
oauth.use_registration_method(&server_metadata, &store_method).await.unwrap();
oauth.use_registration_data(&server_metadata, Some(&client_metadata.into())).await.unwrap();
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
}

View File

@@ -25,8 +25,8 @@ use matrix_sdk::{
authentication::oauth::{
error::OAuthClientRegistrationError,
registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError,
OAuthRegistrationStore, OAuthSession, UrlOrQuery, UserSession,
AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError, OAuthSession,
UrlOrQuery, UserSession,
},
config::SyncSettings,
encryption::{recovery::RecoveryState, CrossSigningResetAuthType},
@@ -139,7 +139,7 @@ impl OAuthCli {
let (client, client_session) = build_client(data_dir).await?;
let cli = Self { client, restored: false, session_file };
if let Err(error) = cli.register_and_login(data_dir).await {
if let Err(error) = cli.register_and_login().await {
if error.downcast_ref::<OAuthError>().is_some_and(|error| {
matches!(
error,
@@ -179,28 +179,18 @@ impl OAuthCli {
/// Register the client and log in the user via the OAuth 2.0 Authorization
/// Code flow.
async fn register_and_login(&self, data_dir: &Path) -> anyhow::Result<()> {
async fn register_and_login(&self) -> anyhow::Result<()> {
let oauth = self.client.oauth();
// We create a loop here so the user can retry if an error happens.
loop {
// The registrations store allows to store the client ID separately
// and reuse it between sessions, so we don't need to
// re-register each time.
// We put it in the parent directory because we remove the data directory on
// logout.
let registration_file =
data_dir.parent().expect("directory has a parent").join("oauth_registrations");
let registration_store =
OAuthRegistrationStore::new(registration_file, client_metadata()).await?;
// Here we spawn a server to listen on the loopback interface. Another option
// would be to register a custom URI scheme with the system and handle
// the redirect when the custom URI scheme is opened.
let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
let OAuthAuthorizationData { url, .. } =
oauth.login(registration_store.into(), redirect_uri, None).build().await?;
oauth.login(redirect_uri, None, Some(client_metadata().into())).build().await?;
let query_string =
use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default();

View File

@@ -112,10 +112,10 @@ async fn login(proxy: Option<Url>) -> Result<()> {
let client = client.build().await?;
let metadata = client_metadata();
let registration_data = client_metadata().into();
let oauth = client.oauth();
let login_client = oauth.login_with_qr_code(&data, metadata.into());
let login_client = oauth.login_with_qr_code(&data, Some(&registration_data));
let mut subscriber = login_client.subscribe_to_progress();
let task = tokio::spawn(async move {