diff --git a/bindings/matrix-sdk-ffi/src/authentication.rs b/bindings/matrix-sdk-ffi/src/authentication.rs index 9517de9e1..6e15fd283 100644 --- a/bindings/matrix-sdk-ffi/src/authentication.rs +++ b/bindings/matrix-sdk-ffi/src/authentication.rs @@ -1,15 +1,14 @@ use std::{ collections::HashMap, fmt::{self, Debug}, - path::PathBuf, sync::Arc, }; use matrix_sdk::{ authentication::oauth::{ - error::{OAuthAuthorizationCodeError, OAuthRegistrationStoreError}, + error::OAuthAuthorizationCodeError, registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType}, - ClientId, OAuthError as SdkOAuthError, OAuthRegistrationStore, + ClientId, ClientRegistrationData, OAuthError as SdkOAuthError, }, Error, }; @@ -123,14 +122,12 @@ pub struct OidcConfiguration { /// An array of e-mail addresses of people responsible for this client. pub contacts: Option>, - /// Pre-configured registrations for use with issuers that don't support + /// Pre-configured registrations for use with homeservers that don't support /// dynamic client registration. - pub static_registrations: HashMap, - - /// A file path where any dynamic registrations should be stored. /// - /// Suggested value: `{base_path}/oidc/registrations.json` - pub dynamic_registrations_file: String, + /// The keys of the map should be the URLs of the homeservers, but keys + /// using `issuer` URLs are also supported. + pub static_registrations: HashMap, } impl OidcConfiguration { @@ -165,12 +162,10 @@ impl OidcConfiguration { Raw::new(&metadata).map_err(|_| OidcError::MetadataInvalid) } - pub async fn registrations(&self) -> Result { + pub(crate) fn registration_data(&self) -> Result { let client_metadata = self.client_metadata()?; - let registrations_file = PathBuf::from(&self.dynamic_registrations_file); - let mut registrations = - OAuthRegistrationStore::new(registrations_file, client_metadata).await?; + let mut registration_data = ClientRegistrationData::new(client_metadata); if !self.static_registrations.is_empty() { let static_registrations = self @@ -185,10 +180,10 @@ impl OidcConfiguration { }) .collect(); - registrations = registrations.with_static_registrations(static_registrations); + registration_data.static_registrations = Some(static_registrations); } - Ok(registrations) + Ok(registration_data) } } @@ -201,8 +196,6 @@ pub enum OidcError { NotSupported, #[error("Unable to use OIDC as the supplied client metadata is invalid.")] MetadataInvalid, - #[error("Failed to use the supplied registrations file path.")] - RegistrationsPathInvalid, #[error("The supplied callback URL used to complete OIDC is invalid.")] CallbackUrlInvalid, #[error("The OIDC login was cancelled by the user.")] @@ -228,17 +221,6 @@ impl From for OidcError { } } -impl From for OidcError { - fn from(e: OAuthRegistrationStoreError) -> OidcError { - match e { - OAuthRegistrationStoreError::NotAFilePath | OAuthRegistrationStoreError::File(_) => { - OidcError::RegistrationsPathInvalid - } - _ => OidcError::Generic { message: e.to_string() }, - } - } -} - impl From for OidcError { fn from(e: Error) -> OidcError { match e { diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index 247a9ffb4..783938c37 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -408,10 +408,10 @@ impl Client { oidc_configuration: &OidcConfiguration, prompt: Option, ) -> Result, OidcError> { - let registrations = oidc_configuration.registrations().await?; + let registration_data = oidc_configuration.registration_data()?; let redirect_uri = oidc_configuration.redirect_uri()?; - let mut url_builder = self.inner.oauth().login(registrations.into(), redirect_uri, None); + let mut url_builder = self.inner.oauth().login(redirect_uri, None, Some(registration_data)); if let Some(prompt) = prompt { url_builder = url_builder.prompt(vec![prompt.into()]); diff --git a/bindings/matrix-sdk-ffi/src/client_builder.rs b/bindings/matrix-sdk-ffi/src/client_builder.rs index 04da5d06c..4ae294248 100644 --- a/bindings/matrix-sdk-ffi/src/client_builder.rs +++ b/bindings/matrix-sdk-ffi/src/client_builder.rs @@ -755,13 +755,12 @@ impl ClientBuilder { } })?; - let registrations = oidc_configuration - .registrations() - .await + let registration_data = oidc_configuration + .registration_data() .map_err(|_| HumanQrLoginError::OidcMetadataInvalid)?; let oauth = client.inner.oauth(); - let login = oauth.login_with_qr_code(&qr_code_data.inner, registrations.into()); + let login = oauth.login_with_qr_code(&qr_code_data.inner, Some(®istration_data)); let mut progress = login.subscribe_to_progress(); diff --git a/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs b/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs index 8c88b47de..bd9a533a0 100644 --- a/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs +++ b/crates/matrix-sdk/src/authentication/oauth/auth_code_builder.rs @@ -24,7 +24,7 @@ use ruma::{ use tracing::{info, instrument}; use url::Url; -use super::{ClientRegistrationMethod, OAuth, OAuthError}; +use super::{ClientRegistrationData, OAuth, OAuthError}; use crate::{authentication::oauth::AuthorizationValidationData, Result}; /// Builder type used to configure optional settings for authorization with an @@ -34,7 +34,7 @@ use crate::{authentication::oauth::AuthorizationValidationData, Result}; #[allow(missing_debug_implementations)] pub struct OAuthAuthCodeUrlBuilder { oauth: OAuth, - registration_method: ClientRegistrationMethod, + registration_data: Option, scopes: Vec, device_id: OwnedDeviceId, redirect_uri: Url, @@ -45,14 +45,14 @@ pub struct OAuthAuthCodeUrlBuilder { impl OAuthAuthCodeUrlBuilder { pub(super) fn new( oauth: OAuth, - registration_method: ClientRegistrationMethod, scopes: Vec, device_id: OwnedDeviceId, redirect_uri: Url, + registration_data: Option, ) -> Self { Self { oauth, - registration_method, + registration_data, scopes, device_id, redirect_uri, @@ -93,19 +93,12 @@ impl OAuthAuthCodeUrlBuilder { /// request fails. #[instrument(target = "matrix_sdk::client", skip_all)] pub async fn build(self) -> Result { - let Self { - oauth, - registration_method, - scopes, - device_id, - redirect_uri, - prompt, - login_hint, - } = self; + let Self { oauth, registration_data, scopes, device_id, redirect_uri, prompt, login_hint } = + self; let server_metadata = oauth.server_metadata().await?; - oauth.use_registration_method(&server_metadata, ®istration_method).await?; + oauth.use_registration_data(&server_metadata, registration_data.as_ref()).await?; let data = oauth.data().expect("OAuth 2.0 data should be set after registration"); info!( diff --git a/crates/matrix-sdk/src/authentication/oauth/error.rs b/crates/matrix-sdk/src/authentication/oauth/error.rs index 1820e049f..80645d984 100644 --- a/crates/matrix-sdk/src/authentication/oauth/error.rs +++ b/crates/matrix-sdk/src/authentication/oauth/error.rs @@ -31,8 +31,6 @@ use ruma::{ #[cfg(feature = "e2e-encryption")] pub use super::cross_process::CrossProcessRefreshLockError; -#[cfg(not(target_arch = "wasm32"))] -pub use super::registration_store::OAuthRegistrationStoreError; /// An error when interacting with the OAuth 2.0 authorization server. pub type OAuthRequestError = @@ -258,11 +256,6 @@ pub enum OAuthClientRegistrationError { /// Deserialization of the registration response failed. #[error("failed to deserialize registration response: {0}")] FromJson(serde_json::Error), - - /// Failed to access or store the registration in the store. - #[cfg(not(target_arch = "wasm32"))] - #[error("failed to use registration store: {0}")] - Store(#[from] OAuthRegistrationStoreError), } /// Error response returned by server after requesting an authorization code. diff --git a/crates/matrix-sdk/src/authentication/oauth/mod.rs b/crates/matrix-sdk/src/authentication/oauth/mod.rs index 653183df5..b3f38d0c7 100644 --- a/crates/matrix-sdk/src/authentication/oauth/mod.rs +++ b/crates/matrix-sdk/src/authentication/oauth/mod.rs @@ -176,8 +176,6 @@ mod oidc_discovery; #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] pub mod qrcode; pub mod registration; -#[cfg(not(target_arch = "wasm32"))] -mod registration_store; #[cfg(all(test, not(target_arch = "wasm32")))] mod tests; @@ -185,8 +183,6 @@ mod tests; use self::cross_process::{CrossProcessRefreshLockGuard, CrossProcessRefreshManager}; #[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))] use self::qrcode::LoginWithQrCode; -#[cfg(not(target_arch = "wasm32"))] -pub use self::registration_store::OAuthRegistrationStore; pub use self::{ account_management_url::{AccountManagementActionFull, AccountManagementUrlBuilder}, auth_code_builder::{OAuthAuthCodeUrlBuilder, OAuthAuthorizationData}, @@ -351,6 +347,15 @@ impl OAuth { /// the private cross-signing keys and the backup key from the existing /// device to the new device. /// + /// # Arguments + /// + /// * `data` - The data scanned from a QR code. + /// + /// * `registration_data` - The data to restore or register the client with + /// the server. If this is not provided, an error will occur unless + /// [`OAuth::register_client()`] or [`OAuth::restore_registered_client()`] + /// was called previously. + /// /// # Example /// /// ```no_run @@ -384,11 +389,12 @@ impl OAuth { /// .await?; /// /// let oauth = client.oauth(); - /// let metadata: Raw = client_metadata(); + /// let client_metadata: Raw = client_metadata(); + /// let registration_data = client_metadata.into(); /// /// // Subscribing to the progress is necessary since we need to input the check /// // code on the existing device. - /// let login = oauth.login_with_qr_code(&qr_code_data, metadata.into()); + /// let login = oauth.login_with_qr_code(&qr_code_data, Some(®istration_data)); /// let mut progress = login.subscribe_to_progress(); /// /// // Create a task which will show us the progress and tell us the check @@ -420,73 +426,44 @@ impl OAuth { pub fn login_with_qr_code<'a>( &'a self, data: &'a QrCodeData, - registration_method: ClientRegistrationMethod, + registration_data: Option<&'a ClientRegistrationData>, ) -> LoginWithQrCode<'a> { - LoginWithQrCode::new(&self.client, registration_method, data) + LoginWithQrCode::new(&self.client, data, registration_data) } /// Restore or register the OAuth 2.0 client for the server with the given - /// metadata, with the given [`OAuthRegistrationStore`]. - /// - /// If there is a client ID in the store, it is used to restore the client. - /// Otherwise, the client is registered with the metadata in the store. - /// - /// Returns an error if there is an error while accessing the store, or - /// while registering the client. - #[cfg(not(target_arch = "wasm32"))] - async fn restore_or_register_client( - &self, - server_metadata: &AuthorizationServerMetadata, - registrations: &OAuthRegistrationStore, - ) -> std::result::Result<(), OAuthClientRegistrationError> { - if let Some(client_id) = registrations.client_id(&server_metadata.issuer).await? { - self.restore_registered_client(server_metadata.issuer.clone(), client_id); - - tracing::info!("OAuth 2.0 configuration loaded from disk."); - return Ok(()); - }; - - tracing::info!("Registering this client for OAuth 2.0."); - let response = self.register_client_inner(server_metadata, ®istrations.metadata).await?; - - tracing::info!("Persisting OAuth 2.0 registration data."); - registrations - .set_and_write_client_id(response.client_id, server_metadata.issuer.clone()) - .await?; - - Ok(()) - } - - /// Restore or register the OAuth 2.0 client for the server with the given - /// metadata, with the given [`ClientRegistrationMethod`]. + /// metadata, with the given optional [`ClientRegistrationData`]. /// /// If we already have a client ID, this is a noop. /// /// Returns an error if there was a problem using the registration method. - async fn use_registration_method( + async fn use_registration_data( &self, server_metadata: &AuthorizationServerMetadata, - method: &ClientRegistrationMethod, + data: Option<&ClientRegistrationData>, ) -> std::result::Result<(), OAuthError> { if self.client_id().is_some() { tracing::info!("OAuth 2.0 is already configured."); return Ok(()); }; - match method { - ClientRegistrationMethod::None => return Err(OAuthError::NotRegistered), - ClientRegistrationMethod::ClientId(client_id) => { + let Some(data) = data else { + return Err(OAuthError::NotRegistered); + }; + + if let Some(static_registrations) = &data.static_registrations { + let client_id = static_registrations + .get(&self.client.homeserver()) + .or_else(|| static_registrations.get(&server_metadata.issuer)); + + if let Some(client_id) = client_id { self.restore_registered_client(server_metadata.issuer.clone(), client_id.clone()); - } - ClientRegistrationMethod::Metadata(client_metadata) => { - self.register_client_inner(server_metadata, client_metadata).await?; - } - #[cfg(not(target_arch = "wasm32"))] - ClientRegistrationMethod::Store(registrations) => { - self.restore_or_register_client(server_metadata, registrations).await? + return Ok(()); } } + self.register_client_inner(server_metadata, &data.metadata).await?; + Ok(()) } @@ -923,18 +900,12 @@ impl OAuth { /// Login via OAuth 2.0 with the Authorization Code flow. /// - /// This should be called after [`OAuth::register_client()`] or - /// [`OAuth::restore_registered_client()`]. - /// /// [`OAuth::finish_login()`] must be called once the user has been /// redirected to the `redirect_uri`. [`OAuth::abort_login()`] should be /// called instead if the authorization should be aborted before completion. /// /// # Arguments /// - /// * `registration_method` - The method to restore or register the client - /// with the server. - /// /// * `redirect_uri` - The URI where the end user will be redirected after /// authorizing the login. It must be one of the redirect URIs sent in the /// client metadata during registration. @@ -944,16 +915,19 @@ impl OAuth { /// device ID from a previous login call. Note that this should be done /// only if the client also holds the corresponding encryption keys. /// + /// * `registration_data` - The data to restore or register the client with + /// the server. If this is not provided, an error will occur unless + /// [`OAuth::register_client()`] or [`OAuth::restore_registered_client()`] + /// was called previously. + /// /// # Example /// /// ```no_run - /// use anyhow::anyhow; /// use matrix_sdk::{ + /// authentication::oauth::registration::ClientMetadata, + /// ruma::serde::Raw, /// Client, - /// authentication::oauth::OAuthRegistrationStore, /// }; - /// # use ruma::serde::Raw; - /// # use matrix_sdk::authentication::oauth::registration::ClientMetadata; /// # let homeserver = unimplemented!(); /// # let redirect_uri = unimplemented!(); /// # let issuer_info = unimplemented!(); @@ -964,13 +938,10 @@ impl OAuth { /// # _ = async { /// # let client = Client::new(homeserver).await?; /// let oauth = client.oauth(); + /// let client_metadata: Raw = client_metadata(); + /// let registration_data = client_metadata.into(); /// - /// let registration_store = OAuthRegistrationStore::new( - /// store_path, - /// client_metadata() - /// ).await?; - /// - /// let auth_data = oauth.login(registration_store.into(), redirect_uri, None) + /// let auth_data = oauth.login(redirect_uri, None, Some(registration_data)) /// .build() /// .await?; /// @@ -980,7 +951,7 @@ impl OAuth { /// oauth.finish_login(redirected_to_uri.into()).await?; /// /// // The session tokens can be persisted from the - /// // `Client::session_tokens()` method. + /// // `OAuth::full_session()` method. /// /// // You can now make requests to the Matrix API. /// let _me = client.whoami().await?; @@ -988,18 +959,18 @@ impl OAuth { /// ``` pub fn login( &self, - registration_method: ClientRegistrationMethod, redirect_uri: Url, device_id: Option, + registration_data: Option, ) -> OAuthAuthCodeUrlBuilder { let (scopes, device_id) = Self::login_scopes(device_id); OAuthAuthCodeUrlBuilder::new( self.clone(), - registration_method, scopes.to_vec(), device_id, redirect_uri, + registration_data, ) } @@ -1547,47 +1518,32 @@ fn hash_str(x: &str) -> impl fmt::LowerHex { sha2::Sha256::new().chain_update(x).finalize() } -/// The available methods to register or restore a client. -#[derive(Debug)] -pub enum ClientRegistrationMethod { - /// No registration will be done. - /// - /// This should only be set if [`OAuth::register_client()`] or - /// [`OAuth::restore_registered_client()`] was already called before. - None, +/// Data to register or restore a client. +#[derive(Debug, Clone)] +pub struct ClientRegistrationData { + /// The metadata to use to register the client when using dynamic client + /// registration. + pub metadata: Raw, - /// The given client ID will be used. + /// Static registrations for servers that don't support dynamic registration + /// but provide a client ID out-of-band. /// - /// This will call [`OAuth::restore_registered_client()`] internally. - ClientId(ClientId), - - /// The client will register using dynamic client registration, with the - /// given metadata. - /// - /// This will call [`OAuth::register_client()`] internally. - Metadata(Raw), - - /// Use an [`OAuthRegistrationStore`] to handle registrations. - #[cfg(not(target_arch = "wasm32"))] - Store(OAuthRegistrationStore), + /// The keys of the map should be the URLs of the homeservers, but keys + /// using `issuer` URLs are also supported. + pub static_registrations: Option>, } -impl From for ClientRegistrationMethod { - fn from(value: ClientId) -> Self { - Self::ClientId(value) +impl ClientRegistrationData { + /// Construct a [`ClientRegistrationData`] with the given metadata and no + /// static registrations. + pub fn new(metadata: Raw) -> Self { + Self { metadata, static_registrations: None } } } -impl From> for ClientRegistrationMethod { +impl From> for ClientRegistrationData { fn from(value: Raw) -> Self { - Self::Metadata(value) - } -} - -#[cfg(not(target_arch = "wasm32"))] -impl From for ClientRegistrationMethod { - fn from(value: OAuthRegistrationStore) -> Self { - Self::Store(value) + Self::new(value) } } diff --git a/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs b/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs index fbda819a3..52e0c7352 100644 --- a/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs +++ b/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs @@ -38,7 +38,7 @@ use super::{ #[cfg(doc)] use crate::authentication::oauth::OAuth; use crate::{ - authentication::oauth::{ClientRegistrationMethod, OAuthError}, + authentication::oauth::{ClientRegistrationData, OAuthError}, Client, }; @@ -83,7 +83,7 @@ pub enum LoginProgress { #[derive(Debug)] pub struct LoginWithQrCode<'a> { client: &'a Client, - registration_method: ClientRegistrationMethod, + registration_data: Option<&'a ClientRegistrationData>, qr_code_data: &'a QrCodeData, state: SharedObservable, } @@ -270,10 +270,10 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> { impl<'a> LoginWithQrCode<'a> { pub(crate) fn new( client: &'a Client, - registration_method: ClientRegistrationMethod, qr_code_data: &'a QrCodeData, + registration_data: Option<&'a ClientRegistrationData>, ) -> LoginWithQrCode<'a> { - LoginWithQrCode { client, registration_method, qr_code_data, state: Default::default() } + LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() } } async fn establish_secure_channel( @@ -299,7 +299,7 @@ impl<'a> LoginWithQrCode<'a> { ) -> Result { let oauth = self.client.oauth(); let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?; - oauth.use_registration_method(&server_metadata, &self.registration_method).await?; + oauth.use_registration_data(&server_metadata, self.registration_data).await?; Ok(server_metadata) } @@ -465,7 +465,8 @@ mod test { let qr_code = alice.qr_code_data().clone(); let oauth = bob.oauth(); - let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into()); + let registration_data = mock_client_metadata().into(); + let login_bob = oauth.login_with_qr_code(&qr_code, Some(®istration_data)); let mut updates = login_bob.subscribe_to_progress(); let updates_task = tokio::spawn(async move { @@ -552,7 +553,8 @@ mod test { let qr_code = alice.qr_code_data().clone(); let oauth = bob.oauth(); - let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into()); + let registration_data = mock_client_metadata().into(); + let login_bob = oauth.login_with_qr_code(&qr_code, Some(®istration_data)); let mut updates = login_bob.subscribe_to_progress(); let _updates_task = tokio::spawn(async move { @@ -675,7 +677,8 @@ mod test { let qr_code = alice.qr_code_data().clone(); let oauth = bob.oauth(); - let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into()); + let registration_data = mock_client_metadata().into(); + let login_bob = oauth.login_with_qr_code(&qr_code, Some(®istration_data)); let mut updates = login_bob.subscribe_to_progress(); let _updates_task = tokio::spawn(async move { diff --git a/crates/matrix-sdk/src/authentication/oauth/registration_store.rs b/crates/matrix-sdk/src/authentication/oauth/registration_store.rs deleted file mode 100644 index 0b1cb6a52..000000000 --- a/crates/matrix-sdk/src/authentication/oauth/registration_store.rs +++ /dev/null @@ -1,317 +0,0 @@ -// Copyright 2023 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! OAuth 2.0 client registration store. -//! -//! This module provides a way to persist OAuth 2.0 client registrations outside -//! of the state store. This is useful when using a `Client` with an in-memory -//! store or when different store paths are used for multi-account support -//! within the same app, and those accounts need to share the same OAuth 2.0 -//! client registration. - -use std::{collections::HashMap, io::ErrorKind, path::PathBuf}; - -use oauth2::ClientId; -use ruma::serde::Raw; -use serde::{Deserialize, Serialize}; -use tokio::fs; -use url::Url; - -use super::ClientMetadata; - -/// Errors that can occur when using the [`OAuthRegistrationStore`]. -#[derive(Debug, thiserror::Error)] -pub enum OAuthRegistrationStoreError { - /// The supplied path is not a file path. - #[error("supplied registrations path is not a file path")] - NotAFilePath, - /// An error occurred when reading from or writing to the file. - #[error(transparent)] - File(#[from] std::io::Error), - /// An error occurred when serializing the registration data. - #[error("failed to serialize registration data: {0}")] - IntoJson(serde_json::Error), - /// An error occurred when deserializing the registration data. - #[error("failed to deserialize registration data: {0}")] - FromJson(serde_json::Error), -} - -/// An API to store and restore OAuth 2.0 client registrations. -/// -/// This stores dynamic client registrations in a file, and accepts "static" -/// client registrations via -/// [`OAuthRegistrationStore::with_static_registrations()`], for servers that -/// don't support dynamic client registration. -/// -/// If the client metadata passed to this API changes, the previous -/// registrations that were stored in the file are invalidated, allowing to -/// re-register with the new metadata. -/// -/// The purpose of storing client IDs outside of the state store or separate -/// from the user's session is that it allows to reuse the same client ID -/// between user sessions on the same server. -#[derive(Debug)] -pub struct OAuthRegistrationStore { - /// The path of the file where the registrations are stored. - pub(super) file_path: PathBuf, - /// The metadata used to register the client. - /// This is used to check if the client needs to be re-registered. - pub(super) metadata: Raw, - /// Pre-configured registrations for use with issuers that don't support - /// dynamic client registration. - static_registrations: Option>, -} - -/// The underlying data serialized into the registration file. -#[derive(Debug, Serialize, Deserialize)] -struct FrozenRegistrationData { - /// The metadata used to register the client. - metadata: Raw, - /// All of the registrations this client has made as a HashMap of issuer URL - /// to client ID. - dynamic_registrations: HashMap, -} - -impl OAuthRegistrationStore { - /// Creates a new registration store. - /// - /// This method creates the `file`'s parent directory if it doesn't exist. - /// - /// # Arguments - /// - /// * `file` - A file path where the registrations will be stored. This - /// previously took a directory and stored the registrations with the path - /// `supplied_directory/oidc/registrations.json`. - /// - /// * `metadata` - The metadata used to register the client. If this changes - /// compared to the value stored in the file, any stored registrations - /// will be invalidated so the client can re-register with the new data. - pub async fn new( - file: PathBuf, - metadata: Raw, - ) -> Result { - let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?; - fs::create_dir_all(parent).await?; - - Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None }) - } - - /// Add static registrations to the store. - /// - /// Static registrations are used for servers that don't support dynamic - /// registration but provide a client ID out-of-band. - /// - /// These registrations are not stored in the file and must be provided each - /// time. - pub fn with_static_registrations( - mut self, - static_registrations: HashMap, - ) -> Self { - self.static_registrations = Some(static_registrations); - self - } - - /// Returns the client ID registered for a particular issuer or `None` if a - /// registration hasn't been made. - /// - /// # Arguments - /// - /// * `issuer` - The issuer to look up. - /// - /// # Errors - /// - /// Returns an error if the file could not be read, or if the data in the - /// file could not be deserialized. - pub async fn client_id( - &self, - issuer: &Url, - ) -> Result, OAuthRegistrationStoreError> { - if let Some(client_id) = - self.static_registrations.as_ref().and_then(|registrations| registrations.get(issuer)) - { - return Ok(Some(client_id.clone())); - } - - let data = self.read_registration_data().await?; - Ok(data.and_then(|mut data| data.dynamic_registrations.remove(issuer))) - } - - /// Stores a new client ID registration for a particular issuer. - /// - /// If a client ID has already been stored for the given issuer, this will - /// overwrite the old value. - /// - /// # Arguments - /// - /// * `client_id` - The client ID obtained after registration. - /// - /// * `issuer` - The issuer associated with the client ID. - /// - /// # Errors - /// - /// Returns an error if the file could not be read from or written to, or if - /// the data in the file could not be (de)serialized. - pub async fn set_and_write_client_id( - &self, - client_id: ClientId, - issuer: Url, - ) -> Result<(), OAuthRegistrationStoreError> { - let mut data = self.read_registration_data().await?.unwrap_or_else(|| { - tracing::info!("Generating new OAuth 2.0 client registration data"); - FrozenRegistrationData { - metadata: self.metadata.clone(), - dynamic_registrations: Default::default(), - } - }); - data.dynamic_registrations.insert(issuer, client_id); - - let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?; - fs::write(&self.file_path, contents).await?; - - Ok(()) - } - - /// The persisted registration data. - /// - /// # Errors - /// - /// Returns an error if the file could not be read, or if the data in the - /// file could not be deserialized. - async fn read_registration_data( - &self, - ) -> Result, OAuthRegistrationStoreError> { - let contents = match fs::read(&self.file_path).await { - Ok(contents) => contents, - Err(error) => { - if error.kind() == ErrorKind::NotFound { - // The file doesn't exist so there is no data. - return Ok(None); - } - - // Forward the error. - return Err(error.into()); - } - }; - - let registration_data: FrozenRegistrationData = - serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?; - - if registration_data.metadata.json().get() != self.metadata.json().get() { - tracing::info!("Metadata mismatch, ignoring any stored registrations."); - Ok(None) - } else { - Ok(Some(registration_data)) - } - } -} - -#[cfg(test)] -mod tests { - use matrix_sdk_test::async_test; - use tempfile::tempdir; - - use super::*; - use crate::authentication::oauth::registration::{ApplicationType, Localized, OAuthGrantType}; - - #[async_test] - async fn test_oauth_registration_store() { - // Given a fresh registration store with a single static registration. - let dir = tempdir().unwrap(); - let registrations_file = dir.path().join("oauth").join("registrations.json"); - - let static_url = Url::parse("https://example.com").unwrap(); - let static_id = ClientId::new("static_client_id".to_owned()); - let dynamic_url = Url::parse("https://example.org").unwrap(); - let dynamic_id = ClientId::new("dynamic_client_id".to_owned()); - - let mut static_registrations = HashMap::new(); - static_registrations.insert(static_url.clone(), static_id.clone()); - - let oidc_metadata = mock_metadata("Example".to_owned()); - - let registrations = OAuthRegistrationStore::new(registrations_file, oidc_metadata) - .await - .unwrap() - .with_static_registrations(static_registrations); - - assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone())); - assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None); - - // When a dynamic registration is added. - registrations - .set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()) - .await - .unwrap(); - - // Then the dynamic registration should be stored and the static registration - // should be unaffected. - assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id)); - assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id)); - } - - #[async_test] - async fn test_change_of_metadata() { - // Given a single registration with an example app name. - let dir = tempdir().unwrap(); - let registrations_file = dir.path().join("oidc").join("registrations.json"); - - let static_url = Url::parse("https://example.com").unwrap(); - let static_id = ClientId::new("static_client_id".to_owned()); - let dynamic_url = Url::parse("https://example.org").unwrap(); - let dynamic_id = ClientId::new("dynamic_client_id".to_owned()); - - let oidc_metadata = mock_metadata("Example".to_owned()); - - let mut static_registrations = HashMap::new(); - static_registrations.insert(static_url.clone(), static_id.clone()); - - let registrations = OAuthRegistrationStore::new(registrations_file.clone(), oidc_metadata) - .await - .unwrap() - .with_static_registrations(static_registrations.clone()); - registrations - .set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()) - .await - .unwrap(); - - assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone())); - assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id)); - - // When the app name changes. - let new_oidc_metadata = mock_metadata("New App".to_owned()); - - let registrations = OAuthRegistrationStore::new(registrations_file, new_oidc_metadata) - .await - .unwrap() - .with_static_registrations(static_registrations); - - // Then the dynamic registrations are cleared. - assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None); - assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id)); - } - - fn mock_metadata(client_name: String) -> Raw { - let callback_url = Url::parse("https://example.org/login/callback").unwrap(); - let client_uri = Url::parse("https://example.org/").unwrap(); - - let mut metadata = ClientMetadata::new( - ApplicationType::Web, - vec![OAuthGrantType::AuthorizationCode { redirect_uris: vec![callback_url] }], - Localized::new(client_uri, None), - ); - metadata.client_name = Some(Localized::new(client_name, None)); - - Raw::new(&metadata).unwrap() - } -} diff --git a/crates/matrix-sdk/src/authentication/oauth/tests.rs b/crates/matrix-sdk/src/authentication/oauth/tests.rs index e4491fae3..102755aac 100644 --- a/crates/matrix-sdk/src/authentication/oauth/tests.rs +++ b/crates/matrix-sdk/src/authentication/oauth/tests.rs @@ -9,7 +9,6 @@ use ruma::{ owned_device_id, user_id, DeviceId, ServerName, }; use serde_json::json; -use tempfile::tempdir; use tokio::sync::broadcast::error::TryRecvError; use url::Url; use wiremock::{ @@ -18,14 +17,13 @@ use wiremock::{ }; use super::{ - registration_store::OAuthRegistrationStore, AuthorizationCode, AuthorizationError, - AuthorizationResponse, OAuth, OAuthAuthorizationData, OAuthError, RedirectUriQueryParseError, - UrlOrQuery, + AuthorizationCode, AuthorizationError, AuthorizationResponse, OAuth, OAuthAuthorizationData, + OAuthError, RedirectUriQueryParseError, UrlOrQuery, }; use crate::{ authentication::oauth::{ error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError}, - AuthorizationValidationData, ClientRegistrationMethod, OAuthAuthorizationCodeError, + AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError, }, test_utils::{ client::{ @@ -40,7 +38,7 @@ use crate::{ const REDIRECT_URI_STRING: &str = "http://127.0.0.1:6778/oauth/callback"; -async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAuthRegistrationStore)> +async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, ClientRegistrationData)> { let server = MatrixMockServer::new().await; server.mock_who_am_i().ok().named("whoami").mount().await; @@ -53,12 +51,7 @@ async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAu let client = server.client_builder().unlogged().build().await; let client_metadata = mock_client_metadata(); - let registrations_path = - tempdir().unwrap().path().join("matrix-sdk-oauth").join("registrations.json"); - let registrations = - OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap(); - - Ok((client.oauth(), server, mock_redirect_uri(), registrations)) + Ok((client.oauth(), server, mock_redirect_uri(), client_metadata.into())) } /// Check the URL in the given authorization data. @@ -154,35 +147,23 @@ async fn check_authorization_url( #[async_test] async fn test_high_level_login() -> anyhow::Result<()> { // Given a fresh environment. - let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap(); - let registrations_path = registrations.file_path.clone(); - let client_metadata = registrations.metadata.clone(); + let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap(); assert!(oauth.issuer().is_none()); assert!(oauth.client_id().is_none()); // When getting the OIDC login URL. let authorization_data = oauth - .login(registrations.into(), redirect_uri.clone(), None) + .login(redirect_uri.clone(), None, Some(registration_data)) .prompt(vec![Prompt::Create]) .build() .await .unwrap(); // Then the client should be configured correctly. - assert_let!(Some(issuer) = oauth.issuer()); + assert!(oauth.issuer().is_some()); assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); - // The client ID should have been saved in the registration file. - let registrations = - OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap(); - assert_eq!( - registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()), - Some("test_client_id") - ); - - check_authorization_url(&authorization_data, &oauth, issuer, None, Some("create"), None).await; - // When completing the login with a valid callback. redirect_uri.set_query(Some(&format!("code=42&state={}", authorization_data.state.secret()))); @@ -195,24 +176,14 @@ async fn test_high_level_login() -> anyhow::Result<()> { #[async_test] async fn test_high_level_login_cancellation() -> anyhow::Result<()> { // Given a client ready to complete login. - let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap(); - let registrations_path = registrations.file_path.clone(); - let client_metadata = registrations.metadata.clone(); + let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap(); let authorization_data = - oauth.login(registrations.into(), redirect_uri.clone(), None).build().await.unwrap(); + oauth.login(redirect_uri.clone(), None, Some(registration_data)).build().await.unwrap(); assert_let!(Some(issuer) = oauth.issuer()); assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); - // The client ID should have been saved in the registration file. - let registrations = - OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap(); - assert_eq!( - registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()), - Some("test_client_id") - ); - check_authorization_url(&authorization_data, &oauth, issuer, None, None, None).await; // When completing login with a cancellation callback. @@ -235,24 +206,14 @@ async fn test_high_level_login_cancellation() -> anyhow::Result<()> { #[async_test] async fn test_high_level_login_invalid_state() -> anyhow::Result<()> { // Given a client ready to complete login. - let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap(); - let registrations_path = registrations.file_path.clone(); - let client_metadata = registrations.metadata.clone(); + let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap(); let authorization_data = - oauth.login(registrations.into(), redirect_uri.clone(), None).build().await.unwrap(); + oauth.login(redirect_uri.clone(), None, Some(registration_data)).build().await.unwrap(); assert_let!(Some(issuer) = oauth.issuer()); assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); - // The client ID should have been saved in the registration file. - let registrations = - OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap(); - assert_eq!( - registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()), - Some("test_client_id") - ); - check_authorization_url(&authorization_data, &oauth, issuer, None, None, None).await; // When completing login with an old/tampered state. @@ -286,16 +247,14 @@ async fn test_login_url() -> anyhow::Result<()> { let redirect_uri = Url::parse(redirect_uri_str)?; // No extra parameters. - let authorization_data = oauth - .login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone())) - .build() - .await?; + let authorization_data = + oauth.login(redirect_uri.clone(), Some(device_id.clone()), None).build().await?; check_authorization_url(&authorization_data, &oauth, &issuer, Some(&device_id), None, None) .await; // With prompt parameter. let authorization_data = oauth - .login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone())) + .login(redirect_uri.clone(), Some(device_id.clone()), None) .prompt(vec![Prompt::Create]) .build() .await?; @@ -311,7 +270,7 @@ async fn test_login_url() -> anyhow::Result<()> { // With user_id_hint parameter. let authorization_data = oauth - .login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone())) + .login(redirect_uri.clone(), Some(device_id.clone()), None) .user_id_hint(user_id!("@joe:example.org")) .build() .await?; @@ -719,7 +678,7 @@ async fn test_server_metadata() { } #[async_test] -async fn test_client_registration_methods() { +async fn test_client_registration_data() { let server = MatrixMockServer::new().await; let oauth_server = server.oauth(); let server_metadata = oauth_server.server_metadata(); @@ -727,29 +686,26 @@ async fn test_client_registration_methods() { // Without registration we get an error. let client = server.client_builder().unlogged().build().await; let oauth = client.oauth(); - let res = - oauth.use_registration_method(&server_metadata, &ClientRegistrationMethod::None).await; + let res = oauth.use_registration_data(&server_metadata, None).await; assert_matches!(res, Err(OAuthError::NotRegistered)); assert_eq!(oauth.client_id(), None); - // With a client ID. - oauth - .use_registration_method( - &server_metadata, - &ClientRegistrationMethod::ClientId(mock_client_id()), - ) - .await - .unwrap(); + // With a static registration. + let registration_data = ClientRegistrationData { + metadata: mock_client_metadata(), + static_registrations: Some([(server_metadata.issuer.clone(), mock_client_id())].into()), + }; + oauth.use_registration_data(&server_metadata, Some(®istration_data)).await.unwrap(); assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); - // If we call it again we get the same client ID. - oauth - .use_registration_method( - &server_metadata, - &ClientRegistrationMethod::ClientId(ClientId::new("other_client_id".to_owned())), - ) - .await - .unwrap(); + // If we call it again, it's a noop. + let registration_data = ClientRegistrationData { + metadata: mock_client_metadata(), + static_registrations: Some( + [(server_metadata.issuer.clone(), ClientId::new("other_client_id".to_owned()))].into(), + ), + }; + oauth.use_registration_data(&server_metadata, Some(®istration_data)).await.unwrap(); assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); // With metadata we register a new client ID. @@ -765,40 +721,6 @@ async fn test_client_registration_methods() { .mount() .await; - oauth - .use_registration_method( - &server_metadata, - &ClientRegistrationMethod::Metadata(client_metadata.clone()), - ) - .await - .unwrap(); - assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); - - // With registration store we register a new client ID. - let client = server.client_builder().unlogged().build().await; - let oauth = client.oauth(); - - let registrations_path = - tempdir().unwrap().path().join("matrix-sdk-oauth").join("registrations.json"); - let registrations = - OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap(); - let store_method = ClientRegistrationMethod::Store(registrations); - - oauth_server - .mock_registration() - .ok() - .mock_once() - .named("registration_with_store") - .mount() - .await; - - oauth.use_registration_method(&server_metadata, &store_method).await.unwrap(); - assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); - - // With same registration store, the client ID is loaded from the store. - let client = server.client_builder().unlogged().build().await; - let oauth = client.oauth(); - - oauth.use_registration_method(&server_metadata, &store_method).await.unwrap(); + oauth.use_registration_data(&server_metadata, Some(&client_metadata.into())).await.unwrap(); assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id")); } diff --git a/examples/oauth_cli/src/main.rs b/examples/oauth_cli/src/main.rs index f17b39ef4..355ae97a8 100644 --- a/examples/oauth_cli/src/main.rs +++ b/examples/oauth_cli/src/main.rs @@ -25,8 +25,8 @@ use matrix_sdk::{ authentication::oauth::{ error::OAuthClientRegistrationError, registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType}, - AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError, - OAuthRegistrationStore, OAuthSession, UrlOrQuery, UserSession, + AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError, OAuthSession, + UrlOrQuery, UserSession, }, config::SyncSettings, encryption::{recovery::RecoveryState, CrossSigningResetAuthType}, @@ -139,7 +139,7 @@ impl OAuthCli { let (client, client_session) = build_client(data_dir).await?; let cli = Self { client, restored: false, session_file }; - if let Err(error) = cli.register_and_login(data_dir).await { + if let Err(error) = cli.register_and_login().await { if error.downcast_ref::().is_some_and(|error| { matches!( error, @@ -179,28 +179,18 @@ impl OAuthCli { /// Register the client and log in the user via the OAuth 2.0 Authorization /// Code flow. - async fn register_and_login(&self, data_dir: &Path) -> anyhow::Result<()> { + async fn register_and_login(&self) -> anyhow::Result<()> { let oauth = self.client.oauth(); // We create a loop here so the user can retry if an error happens. loop { - // The registrations store allows to store the client ID separately - // and reuse it between sessions, so we don't need to - // re-register each time. - // We put it in the parent directory because we remove the data directory on - // logout. - let registration_file = - data_dir.parent().expect("directory has a parent").join("oauth_registrations"); - let registration_store = - OAuthRegistrationStore::new(registration_file, client_metadata()).await?; - // Here we spawn a server to listen on the loopback interface. Another option // would be to register a custom URI scheme with the system and handle // the redirect when the custom URI scheme is opened. let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?; let OAuthAuthorizationData { url, .. } = - oauth.login(registration_store.into(), redirect_uri, None).build().await?; + oauth.login(redirect_uri, None, Some(client_metadata().into())).build().await?; let query_string = use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default(); diff --git a/examples/qr-login/src/main.rs b/examples/qr-login/src/main.rs index 8eac9a347..43c6c69c5 100644 --- a/examples/qr-login/src/main.rs +++ b/examples/qr-login/src/main.rs @@ -112,10 +112,10 @@ async fn login(proxy: Option) -> Result<()> { let client = client.build().await?; - let metadata = client_metadata(); + let registration_data = client_metadata().into(); let oauth = client.oauth(); - let login_client = oauth.login_with_qr_code(&data, metadata.into()); + let login_client = oauth.login_with_qr_code(&data, Some(®istration_data)); let mut subscriber = login_client.subscribe_to_progress(); let task = tokio::spawn(async move {