mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-04 14:04:40 -04:00
refactor(oauth): Remove OAuthRegistrationStore
MSC2966 was updated, clients should re-register for every log in, so we don't need to store the client IDs between logins. Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
committed by
Stefan Ceriu
parent
c4d9ec98c3
commit
8883e081af
@@ -1,15 +1,14 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt::{self, Debug},
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use matrix_sdk::{
|
||||
authentication::oauth::{
|
||||
error::{OAuthAuthorizationCodeError, OAuthRegistrationStoreError},
|
||||
error::OAuthAuthorizationCodeError,
|
||||
registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
|
||||
ClientId, OAuthError as SdkOAuthError, OAuthRegistrationStore,
|
||||
ClientId, ClientRegistrationData, OAuthError as SdkOAuthError,
|
||||
},
|
||||
Error,
|
||||
};
|
||||
@@ -123,14 +122,12 @@ pub struct OidcConfiguration {
|
||||
/// An array of e-mail addresses of people responsible for this client.
|
||||
pub contacts: Option<Vec<String>>,
|
||||
|
||||
/// Pre-configured registrations for use with issuers that don't support
|
||||
/// Pre-configured registrations for use with homeservers that don't support
|
||||
/// dynamic client registration.
|
||||
pub static_registrations: HashMap<String, String>,
|
||||
|
||||
/// A file path where any dynamic registrations should be stored.
|
||||
///
|
||||
/// Suggested value: `{base_path}/oidc/registrations.json`
|
||||
pub dynamic_registrations_file: String,
|
||||
/// The keys of the map should be the URLs of the homeservers, but keys
|
||||
/// using `issuer` URLs are also supported.
|
||||
pub static_registrations: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl OidcConfiguration {
|
||||
@@ -165,12 +162,10 @@ impl OidcConfiguration {
|
||||
Raw::new(&metadata).map_err(|_| OidcError::MetadataInvalid)
|
||||
}
|
||||
|
||||
pub async fn registrations(&self) -> Result<OAuthRegistrationStore, OidcError> {
|
||||
pub(crate) fn registration_data(&self) -> Result<ClientRegistrationData, OidcError> {
|
||||
let client_metadata = self.client_metadata()?;
|
||||
|
||||
let registrations_file = PathBuf::from(&self.dynamic_registrations_file);
|
||||
let mut registrations =
|
||||
OAuthRegistrationStore::new(registrations_file, client_metadata).await?;
|
||||
let mut registration_data = ClientRegistrationData::new(client_metadata);
|
||||
|
||||
if !self.static_registrations.is_empty() {
|
||||
let static_registrations = self
|
||||
@@ -185,10 +180,10 @@ impl OidcConfiguration {
|
||||
})
|
||||
.collect();
|
||||
|
||||
registrations = registrations.with_static_registrations(static_registrations);
|
||||
registration_data.static_registrations = Some(static_registrations);
|
||||
}
|
||||
|
||||
Ok(registrations)
|
||||
Ok(registration_data)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,8 +196,6 @@ pub enum OidcError {
|
||||
NotSupported,
|
||||
#[error("Unable to use OIDC as the supplied client metadata is invalid.")]
|
||||
MetadataInvalid,
|
||||
#[error("Failed to use the supplied registrations file path.")]
|
||||
RegistrationsPathInvalid,
|
||||
#[error("The supplied callback URL used to complete OIDC is invalid.")]
|
||||
CallbackUrlInvalid,
|
||||
#[error("The OIDC login was cancelled by the user.")]
|
||||
@@ -228,17 +221,6 @@ impl From<SdkOAuthError> for OidcError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OAuthRegistrationStoreError> for OidcError {
|
||||
fn from(e: OAuthRegistrationStoreError) -> OidcError {
|
||||
match e {
|
||||
OAuthRegistrationStoreError::NotAFilePath | OAuthRegistrationStoreError::File(_) => {
|
||||
OidcError::RegistrationsPathInvalid
|
||||
}
|
||||
_ => OidcError::Generic { message: e.to_string() },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Error> for OidcError {
|
||||
fn from(e: Error) -> OidcError {
|
||||
match e {
|
||||
|
||||
@@ -408,10 +408,10 @@ impl Client {
|
||||
oidc_configuration: &OidcConfiguration,
|
||||
prompt: Option<OidcPrompt>,
|
||||
) -> Result<Arc<OAuthAuthorizationData>, OidcError> {
|
||||
let registrations = oidc_configuration.registrations().await?;
|
||||
let registration_data = oidc_configuration.registration_data()?;
|
||||
let redirect_uri = oidc_configuration.redirect_uri()?;
|
||||
|
||||
let mut url_builder = self.inner.oauth().login(registrations.into(), redirect_uri, None);
|
||||
let mut url_builder = self.inner.oauth().login(redirect_uri, None, Some(registration_data));
|
||||
|
||||
if let Some(prompt) = prompt {
|
||||
url_builder = url_builder.prompt(vec![prompt.into()]);
|
||||
|
||||
@@ -755,13 +755,12 @@ impl ClientBuilder {
|
||||
}
|
||||
})?;
|
||||
|
||||
let registrations = oidc_configuration
|
||||
.registrations()
|
||||
.await
|
||||
let registration_data = oidc_configuration
|
||||
.registration_data()
|
||||
.map_err(|_| HumanQrLoginError::OidcMetadataInvalid)?;
|
||||
|
||||
let oauth = client.inner.oauth();
|
||||
let login = oauth.login_with_qr_code(&qr_code_data.inner, registrations.into());
|
||||
let login = oauth.login_with_qr_code(&qr_code_data.inner, Some(®istration_data));
|
||||
|
||||
let mut progress = login.subscribe_to_progress();
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ use ruma::{
|
||||
use tracing::{info, instrument};
|
||||
use url::Url;
|
||||
|
||||
use super::{ClientRegistrationMethod, OAuth, OAuthError};
|
||||
use super::{ClientRegistrationData, OAuth, OAuthError};
|
||||
use crate::{authentication::oauth::AuthorizationValidationData, Result};
|
||||
|
||||
/// Builder type used to configure optional settings for authorization with an
|
||||
@@ -34,7 +34,7 @@ use crate::{authentication::oauth::AuthorizationValidationData, Result};
|
||||
#[allow(missing_debug_implementations)]
|
||||
pub struct OAuthAuthCodeUrlBuilder {
|
||||
oauth: OAuth,
|
||||
registration_method: ClientRegistrationMethod,
|
||||
registration_data: Option<ClientRegistrationData>,
|
||||
scopes: Vec<Scope>,
|
||||
device_id: OwnedDeviceId,
|
||||
redirect_uri: Url,
|
||||
@@ -45,14 +45,14 @@ pub struct OAuthAuthCodeUrlBuilder {
|
||||
impl OAuthAuthCodeUrlBuilder {
|
||||
pub(super) fn new(
|
||||
oauth: OAuth,
|
||||
registration_method: ClientRegistrationMethod,
|
||||
scopes: Vec<Scope>,
|
||||
device_id: OwnedDeviceId,
|
||||
redirect_uri: Url,
|
||||
registration_data: Option<ClientRegistrationData>,
|
||||
) -> Self {
|
||||
Self {
|
||||
oauth,
|
||||
registration_method,
|
||||
registration_data,
|
||||
scopes,
|
||||
device_id,
|
||||
redirect_uri,
|
||||
@@ -93,19 +93,12 @@ impl OAuthAuthCodeUrlBuilder {
|
||||
/// request fails.
|
||||
#[instrument(target = "matrix_sdk::client", skip_all)]
|
||||
pub async fn build(self) -> Result<OAuthAuthorizationData, OAuthError> {
|
||||
let Self {
|
||||
oauth,
|
||||
registration_method,
|
||||
scopes,
|
||||
device_id,
|
||||
redirect_uri,
|
||||
prompt,
|
||||
login_hint,
|
||||
} = self;
|
||||
let Self { oauth, registration_data, scopes, device_id, redirect_uri, prompt, login_hint } =
|
||||
self;
|
||||
|
||||
let server_metadata = oauth.server_metadata().await?;
|
||||
|
||||
oauth.use_registration_method(&server_metadata, ®istration_method).await?;
|
||||
oauth.use_registration_data(&server_metadata, registration_data.as_ref()).await?;
|
||||
|
||||
let data = oauth.data().expect("OAuth 2.0 data should be set after registration");
|
||||
info!(
|
||||
|
||||
@@ -31,8 +31,6 @@ use ruma::{
|
||||
|
||||
#[cfg(feature = "e2e-encryption")]
|
||||
pub use super::cross_process::CrossProcessRefreshLockError;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use super::registration_store::OAuthRegistrationStoreError;
|
||||
|
||||
/// An error when interacting with the OAuth 2.0 authorization server.
|
||||
pub type OAuthRequestError<T> =
|
||||
@@ -258,11 +256,6 @@ pub enum OAuthClientRegistrationError {
|
||||
/// Deserialization of the registration response failed.
|
||||
#[error("failed to deserialize registration response: {0}")]
|
||||
FromJson(serde_json::Error),
|
||||
|
||||
/// Failed to access or store the registration in the store.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[error("failed to use registration store: {0}")]
|
||||
Store(#[from] OAuthRegistrationStoreError),
|
||||
}
|
||||
|
||||
/// Error response returned by server after requesting an authorization code.
|
||||
|
||||
@@ -176,8 +176,6 @@ mod oidc_discovery;
|
||||
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
|
||||
pub mod qrcode;
|
||||
pub mod registration;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod registration_store;
|
||||
#[cfg(all(test, not(target_arch = "wasm32")))]
|
||||
mod tests;
|
||||
|
||||
@@ -185,8 +183,6 @@ mod tests;
|
||||
use self::cross_process::{CrossProcessRefreshLockGuard, CrossProcessRefreshManager};
|
||||
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
|
||||
use self::qrcode::LoginWithQrCode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub use self::registration_store::OAuthRegistrationStore;
|
||||
pub use self::{
|
||||
account_management_url::{AccountManagementActionFull, AccountManagementUrlBuilder},
|
||||
auth_code_builder::{OAuthAuthCodeUrlBuilder, OAuthAuthorizationData},
|
||||
@@ -351,6 +347,15 @@ impl OAuth {
|
||||
/// the private cross-signing keys and the backup key from the existing
|
||||
/// device to the new device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The data scanned from a QR code.
|
||||
///
|
||||
/// * `registration_data` - The data to restore or register the client with
|
||||
/// the server. If this is not provided, an error will occur unless
|
||||
/// [`OAuth::register_client()`] or [`OAuth::restore_registered_client()`]
|
||||
/// was called previously.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
@@ -384,11 +389,12 @@ impl OAuth {
|
||||
/// .await?;
|
||||
///
|
||||
/// let oauth = client.oauth();
|
||||
/// let metadata: Raw<ClientMetadata> = client_metadata();
|
||||
/// let client_metadata: Raw<ClientMetadata> = client_metadata();
|
||||
/// let registration_data = client_metadata.into();
|
||||
///
|
||||
/// // Subscribing to the progress is necessary since we need to input the check
|
||||
/// // code on the existing device.
|
||||
/// let login = oauth.login_with_qr_code(&qr_code_data, metadata.into());
|
||||
/// let login = oauth.login_with_qr_code(&qr_code_data, Some(®istration_data));
|
||||
/// let mut progress = login.subscribe_to_progress();
|
||||
///
|
||||
/// // Create a task which will show us the progress and tell us the check
|
||||
@@ -420,73 +426,44 @@ impl OAuth {
|
||||
pub fn login_with_qr_code<'a>(
|
||||
&'a self,
|
||||
data: &'a QrCodeData,
|
||||
registration_method: ClientRegistrationMethod,
|
||||
registration_data: Option<&'a ClientRegistrationData>,
|
||||
) -> LoginWithQrCode<'a> {
|
||||
LoginWithQrCode::new(&self.client, registration_method, data)
|
||||
LoginWithQrCode::new(&self.client, data, registration_data)
|
||||
}
|
||||
|
||||
/// Restore or register the OAuth 2.0 client for the server with the given
|
||||
/// metadata, with the given [`OAuthRegistrationStore`].
|
||||
///
|
||||
/// If there is a client ID in the store, it is used to restore the client.
|
||||
/// Otherwise, the client is registered with the metadata in the store.
|
||||
///
|
||||
/// Returns an error if there is an error while accessing the store, or
|
||||
/// while registering the client.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
async fn restore_or_register_client(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
registrations: &OAuthRegistrationStore,
|
||||
) -> std::result::Result<(), OAuthClientRegistrationError> {
|
||||
if let Some(client_id) = registrations.client_id(&server_metadata.issuer).await? {
|
||||
self.restore_registered_client(server_metadata.issuer.clone(), client_id);
|
||||
|
||||
tracing::info!("OAuth 2.0 configuration loaded from disk.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
tracing::info!("Registering this client for OAuth 2.0.");
|
||||
let response = self.register_client_inner(server_metadata, ®istrations.metadata).await?;
|
||||
|
||||
tracing::info!("Persisting OAuth 2.0 registration data.");
|
||||
registrations
|
||||
.set_and_write_client_id(response.client_id, server_metadata.issuer.clone())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Restore or register the OAuth 2.0 client for the server with the given
|
||||
/// metadata, with the given [`ClientRegistrationMethod`].
|
||||
/// metadata, with the given optional [`ClientRegistrationData`].
|
||||
///
|
||||
/// If we already have a client ID, this is a noop.
|
||||
///
|
||||
/// Returns an error if there was a problem using the registration method.
|
||||
async fn use_registration_method(
|
||||
async fn use_registration_data(
|
||||
&self,
|
||||
server_metadata: &AuthorizationServerMetadata,
|
||||
method: &ClientRegistrationMethod,
|
||||
data: Option<&ClientRegistrationData>,
|
||||
) -> std::result::Result<(), OAuthError> {
|
||||
if self.client_id().is_some() {
|
||||
tracing::info!("OAuth 2.0 is already configured.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
match method {
|
||||
ClientRegistrationMethod::None => return Err(OAuthError::NotRegistered),
|
||||
ClientRegistrationMethod::ClientId(client_id) => {
|
||||
let Some(data) = data else {
|
||||
return Err(OAuthError::NotRegistered);
|
||||
};
|
||||
|
||||
if let Some(static_registrations) = &data.static_registrations {
|
||||
let client_id = static_registrations
|
||||
.get(&self.client.homeserver())
|
||||
.or_else(|| static_registrations.get(&server_metadata.issuer));
|
||||
|
||||
if let Some(client_id) = client_id {
|
||||
self.restore_registered_client(server_metadata.issuer.clone(), client_id.clone());
|
||||
}
|
||||
ClientRegistrationMethod::Metadata(client_metadata) => {
|
||||
self.register_client_inner(server_metadata, client_metadata).await?;
|
||||
}
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
ClientRegistrationMethod::Store(registrations) => {
|
||||
self.restore_or_register_client(server_metadata, registrations).await?
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
self.register_client_inner(server_metadata, &data.metadata).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -923,18 +900,12 @@ impl OAuth {
|
||||
|
||||
/// Login via OAuth 2.0 with the Authorization Code flow.
|
||||
///
|
||||
/// This should be called after [`OAuth::register_client()`] or
|
||||
/// [`OAuth::restore_registered_client()`].
|
||||
///
|
||||
/// [`OAuth::finish_login()`] must be called once the user has been
|
||||
/// redirected to the `redirect_uri`. [`OAuth::abort_login()`] should be
|
||||
/// called instead if the authorization should be aborted before completion.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `registration_method` - The method to restore or register the client
|
||||
/// with the server.
|
||||
///
|
||||
/// * `redirect_uri` - The URI where the end user will be redirected after
|
||||
/// authorizing the login. It must be one of the redirect URIs sent in the
|
||||
/// client metadata during registration.
|
||||
@@ -944,16 +915,19 @@ impl OAuth {
|
||||
/// device ID from a previous login call. Note that this should be done
|
||||
/// only if the client also holds the corresponding encryption keys.
|
||||
///
|
||||
/// * `registration_data` - The data to restore or register the client with
|
||||
/// the server. If this is not provided, an error will occur unless
|
||||
/// [`OAuth::register_client()`] or [`OAuth::restore_registered_client()`]
|
||||
/// was called previously.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use anyhow::anyhow;
|
||||
/// use matrix_sdk::{
|
||||
/// authentication::oauth::registration::ClientMetadata,
|
||||
/// ruma::serde::Raw,
|
||||
/// Client,
|
||||
/// authentication::oauth::OAuthRegistrationStore,
|
||||
/// };
|
||||
/// # use ruma::serde::Raw;
|
||||
/// # use matrix_sdk::authentication::oauth::registration::ClientMetadata;
|
||||
/// # let homeserver = unimplemented!();
|
||||
/// # let redirect_uri = unimplemented!();
|
||||
/// # let issuer_info = unimplemented!();
|
||||
@@ -964,13 +938,10 @@ impl OAuth {
|
||||
/// # _ = async {
|
||||
/// # let client = Client::new(homeserver).await?;
|
||||
/// let oauth = client.oauth();
|
||||
/// let client_metadata: Raw<ClientMetadata> = client_metadata();
|
||||
/// let registration_data = client_metadata.into();
|
||||
///
|
||||
/// let registration_store = OAuthRegistrationStore::new(
|
||||
/// store_path,
|
||||
/// client_metadata()
|
||||
/// ).await?;
|
||||
///
|
||||
/// let auth_data = oauth.login(registration_store.into(), redirect_uri, None)
|
||||
/// let auth_data = oauth.login(redirect_uri, None, Some(registration_data))
|
||||
/// .build()
|
||||
/// .await?;
|
||||
///
|
||||
@@ -980,7 +951,7 @@ impl OAuth {
|
||||
/// oauth.finish_login(redirected_to_uri.into()).await?;
|
||||
///
|
||||
/// // The session tokens can be persisted from the
|
||||
/// // `Client::session_tokens()` method.
|
||||
/// // `OAuth::full_session()` method.
|
||||
///
|
||||
/// // You can now make requests to the Matrix API.
|
||||
/// let _me = client.whoami().await?;
|
||||
@@ -988,18 +959,18 @@ impl OAuth {
|
||||
/// ```
|
||||
pub fn login(
|
||||
&self,
|
||||
registration_method: ClientRegistrationMethod,
|
||||
redirect_uri: Url,
|
||||
device_id: Option<OwnedDeviceId>,
|
||||
registration_data: Option<ClientRegistrationData>,
|
||||
) -> OAuthAuthCodeUrlBuilder {
|
||||
let (scopes, device_id) = Self::login_scopes(device_id);
|
||||
|
||||
OAuthAuthCodeUrlBuilder::new(
|
||||
self.clone(),
|
||||
registration_method,
|
||||
scopes.to_vec(),
|
||||
device_id,
|
||||
redirect_uri,
|
||||
registration_data,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1547,47 +1518,32 @@ fn hash_str(x: &str) -> impl fmt::LowerHex {
|
||||
sha2::Sha256::new().chain_update(x).finalize()
|
||||
}
|
||||
|
||||
/// The available methods to register or restore a client.
|
||||
#[derive(Debug)]
|
||||
pub enum ClientRegistrationMethod {
|
||||
/// No registration will be done.
|
||||
///
|
||||
/// This should only be set if [`OAuth::register_client()`] or
|
||||
/// [`OAuth::restore_registered_client()`] was already called before.
|
||||
None,
|
||||
/// Data to register or restore a client.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientRegistrationData {
|
||||
/// The metadata to use to register the client when using dynamic client
|
||||
/// registration.
|
||||
pub metadata: Raw<ClientMetadata>,
|
||||
|
||||
/// The given client ID will be used.
|
||||
/// Static registrations for servers that don't support dynamic registration
|
||||
/// but provide a client ID out-of-band.
|
||||
///
|
||||
/// This will call [`OAuth::restore_registered_client()`] internally.
|
||||
ClientId(ClientId),
|
||||
|
||||
/// The client will register using dynamic client registration, with the
|
||||
/// given metadata.
|
||||
///
|
||||
/// This will call [`OAuth::register_client()`] internally.
|
||||
Metadata(Raw<ClientMetadata>),
|
||||
|
||||
/// Use an [`OAuthRegistrationStore`] to handle registrations.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Store(OAuthRegistrationStore),
|
||||
/// The keys of the map should be the URLs of the homeservers, but keys
|
||||
/// using `issuer` URLs are also supported.
|
||||
pub static_registrations: Option<HashMap<Url, ClientId>>,
|
||||
}
|
||||
|
||||
impl From<ClientId> for ClientRegistrationMethod {
|
||||
fn from(value: ClientId) -> Self {
|
||||
Self::ClientId(value)
|
||||
impl ClientRegistrationData {
|
||||
/// Construct a [`ClientRegistrationData`] with the given metadata and no
|
||||
/// static registrations.
|
||||
pub fn new(metadata: Raw<ClientMetadata>) -> Self {
|
||||
Self { metadata, static_registrations: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Raw<ClientMetadata>> for ClientRegistrationMethod {
|
||||
impl From<Raw<ClientMetadata>> for ClientRegistrationData {
|
||||
fn from(value: Raw<ClientMetadata>) -> Self {
|
||||
Self::Metadata(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl From<OAuthRegistrationStore> for ClientRegistrationMethod {
|
||||
fn from(value: OAuthRegistrationStore) -> Self {
|
||||
Self::Store(value)
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ use super::{
|
||||
#[cfg(doc)]
|
||||
use crate::authentication::oauth::OAuth;
|
||||
use crate::{
|
||||
authentication::oauth::{ClientRegistrationMethod, OAuthError},
|
||||
authentication::oauth::{ClientRegistrationData, OAuthError},
|
||||
Client,
|
||||
};
|
||||
|
||||
@@ -83,7 +83,7 @@ pub enum LoginProgress {
|
||||
#[derive(Debug)]
|
||||
pub struct LoginWithQrCode<'a> {
|
||||
client: &'a Client,
|
||||
registration_method: ClientRegistrationMethod,
|
||||
registration_data: Option<&'a ClientRegistrationData>,
|
||||
qr_code_data: &'a QrCodeData,
|
||||
state: SharedObservable<LoginProgress>,
|
||||
}
|
||||
@@ -270,10 +270,10 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
|
||||
impl<'a> LoginWithQrCode<'a> {
|
||||
pub(crate) fn new(
|
||||
client: &'a Client,
|
||||
registration_method: ClientRegistrationMethod,
|
||||
qr_code_data: &'a QrCodeData,
|
||||
registration_data: Option<&'a ClientRegistrationData>,
|
||||
) -> LoginWithQrCode<'a> {
|
||||
LoginWithQrCode { client, registration_method, qr_code_data, state: Default::default() }
|
||||
LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
|
||||
}
|
||||
|
||||
async fn establish_secure_channel(
|
||||
@@ -299,7 +299,7 @@ impl<'a> LoginWithQrCode<'a> {
|
||||
) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
|
||||
let oauth = self.client.oauth();
|
||||
let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
|
||||
oauth.use_registration_method(&server_metadata, &self.registration_method).await?;
|
||||
oauth.use_registration_data(&server_metadata, self.registration_data).await?;
|
||||
|
||||
Ok(server_metadata)
|
||||
}
|
||||
@@ -465,7 +465,8 @@ mod test {
|
||||
let qr_code = alice.qr_code_data().clone();
|
||||
|
||||
let oauth = bob.oauth();
|
||||
let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into());
|
||||
let registration_data = mock_client_metadata().into();
|
||||
let login_bob = oauth.login_with_qr_code(&qr_code, Some(®istration_data));
|
||||
let mut updates = login_bob.subscribe_to_progress();
|
||||
|
||||
let updates_task = tokio::spawn(async move {
|
||||
@@ -552,7 +553,8 @@ mod test {
|
||||
let qr_code = alice.qr_code_data().clone();
|
||||
|
||||
let oauth = bob.oauth();
|
||||
let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into());
|
||||
let registration_data = mock_client_metadata().into();
|
||||
let login_bob = oauth.login_with_qr_code(&qr_code, Some(®istration_data));
|
||||
let mut updates = login_bob.subscribe_to_progress();
|
||||
|
||||
let _updates_task = tokio::spawn(async move {
|
||||
@@ -675,7 +677,8 @@ mod test {
|
||||
let qr_code = alice.qr_code_data().clone();
|
||||
|
||||
let oauth = bob.oauth();
|
||||
let login_bob = oauth.login_with_qr_code(&qr_code, mock_client_metadata().into());
|
||||
let registration_data = mock_client_metadata().into();
|
||||
let login_bob = oauth.login_with_qr_code(&qr_code, Some(®istration_data));
|
||||
let mut updates = login_bob.subscribe_to_progress();
|
||||
|
||||
let _updates_task = tokio::spawn(async move {
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! OAuth 2.0 client registration store.
|
||||
//!
|
||||
//! This module provides a way to persist OAuth 2.0 client registrations outside
|
||||
//! of the state store. This is useful when using a `Client` with an in-memory
|
||||
//! store or when different store paths are used for multi-account support
|
||||
//! within the same app, and those accounts need to share the same OAuth 2.0
|
||||
//! client registration.
|
||||
|
||||
use std::{collections::HashMap, io::ErrorKind, path::PathBuf};
|
||||
|
||||
use oauth2::ClientId;
|
||||
use ruma::serde::Raw;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::fs;
|
||||
use url::Url;
|
||||
|
||||
use super::ClientMetadata;
|
||||
|
||||
/// Errors that can occur when using the [`OAuthRegistrationStore`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OAuthRegistrationStoreError {
|
||||
/// The supplied path is not a file path.
|
||||
#[error("supplied registrations path is not a file path")]
|
||||
NotAFilePath,
|
||||
/// An error occurred when reading from or writing to the file.
|
||||
#[error(transparent)]
|
||||
File(#[from] std::io::Error),
|
||||
/// An error occurred when serializing the registration data.
|
||||
#[error("failed to serialize registration data: {0}")]
|
||||
IntoJson(serde_json::Error),
|
||||
/// An error occurred when deserializing the registration data.
|
||||
#[error("failed to deserialize registration data: {0}")]
|
||||
FromJson(serde_json::Error),
|
||||
}
|
||||
|
||||
/// An API to store and restore OAuth 2.0 client registrations.
|
||||
///
|
||||
/// This stores dynamic client registrations in a file, and accepts "static"
|
||||
/// client registrations via
|
||||
/// [`OAuthRegistrationStore::with_static_registrations()`], for servers that
|
||||
/// don't support dynamic client registration.
|
||||
///
|
||||
/// If the client metadata passed to this API changes, the previous
|
||||
/// registrations that were stored in the file are invalidated, allowing to
|
||||
/// re-register with the new metadata.
|
||||
///
|
||||
/// The purpose of storing client IDs outside of the state store or separate
|
||||
/// from the user's session is that it allows to reuse the same client ID
|
||||
/// between user sessions on the same server.
|
||||
#[derive(Debug)]
|
||||
pub struct OAuthRegistrationStore {
|
||||
/// The path of the file where the registrations are stored.
|
||||
pub(super) file_path: PathBuf,
|
||||
/// The metadata used to register the client.
|
||||
/// This is used to check if the client needs to be re-registered.
|
||||
pub(super) metadata: Raw<ClientMetadata>,
|
||||
/// Pre-configured registrations for use with issuers that don't support
|
||||
/// dynamic client registration.
|
||||
static_registrations: Option<HashMap<Url, ClientId>>,
|
||||
}
|
||||
|
||||
/// The underlying data serialized into the registration file.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct FrozenRegistrationData {
|
||||
/// The metadata used to register the client.
|
||||
metadata: Raw<ClientMetadata>,
|
||||
/// All of the registrations this client has made as a HashMap of issuer URL
|
||||
/// to client ID.
|
||||
dynamic_registrations: HashMap<Url, ClientId>,
|
||||
}
|
||||
|
||||
impl OAuthRegistrationStore {
|
||||
/// Creates a new registration store.
|
||||
///
|
||||
/// This method creates the `file`'s parent directory if it doesn't exist.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file` - A file path where the registrations will be stored. This
|
||||
/// previously took a directory and stored the registrations with the path
|
||||
/// `supplied_directory/oidc/registrations.json`.
|
||||
///
|
||||
/// * `metadata` - The metadata used to register the client. If this changes
|
||||
/// compared to the value stored in the file, any stored registrations
|
||||
/// will be invalidated so the client can re-register with the new data.
|
||||
pub async fn new(
|
||||
file: PathBuf,
|
||||
metadata: Raw<ClientMetadata>,
|
||||
) -> Result<Self, OAuthRegistrationStoreError> {
|
||||
let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
|
||||
fs::create_dir_all(parent).await?;
|
||||
|
||||
Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
|
||||
}
|
||||
|
||||
/// Add static registrations to the store.
|
||||
///
|
||||
/// Static registrations are used for servers that don't support dynamic
|
||||
/// registration but provide a client ID out-of-band.
|
||||
///
|
||||
/// These registrations are not stored in the file and must be provided each
|
||||
/// time.
|
||||
pub fn with_static_registrations(
|
||||
mut self,
|
||||
static_registrations: HashMap<Url, ClientId>,
|
||||
) -> Self {
|
||||
self.static_registrations = Some(static_registrations);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the client ID registered for a particular issuer or `None` if a
|
||||
/// registration hasn't been made.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `issuer` - The issuer to look up.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the file could not be read, or if the data in the
|
||||
/// file could not be deserialized.
|
||||
pub async fn client_id(
|
||||
&self,
|
||||
issuer: &Url,
|
||||
) -> Result<Option<ClientId>, OAuthRegistrationStoreError> {
|
||||
if let Some(client_id) =
|
||||
self.static_registrations.as_ref().and_then(|registrations| registrations.get(issuer))
|
||||
{
|
||||
return Ok(Some(client_id.clone()));
|
||||
}
|
||||
|
||||
let data = self.read_registration_data().await?;
|
||||
Ok(data.and_then(|mut data| data.dynamic_registrations.remove(issuer)))
|
||||
}
|
||||
|
||||
/// Stores a new client ID registration for a particular issuer.
|
||||
///
|
||||
/// If a client ID has already been stored for the given issuer, this will
|
||||
/// overwrite the old value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `client_id` - The client ID obtained after registration.
|
||||
///
|
||||
/// * `issuer` - The issuer associated with the client ID.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the file could not be read from or written to, or if
|
||||
/// the data in the file could not be (de)serialized.
|
||||
pub async fn set_and_write_client_id(
|
||||
&self,
|
||||
client_id: ClientId,
|
||||
issuer: Url,
|
||||
) -> Result<(), OAuthRegistrationStoreError> {
|
||||
let mut data = self.read_registration_data().await?.unwrap_or_else(|| {
|
||||
tracing::info!("Generating new OAuth 2.0 client registration data");
|
||||
FrozenRegistrationData {
|
||||
metadata: self.metadata.clone(),
|
||||
dynamic_registrations: Default::default(),
|
||||
}
|
||||
});
|
||||
data.dynamic_registrations.insert(issuer, client_id);
|
||||
|
||||
let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?;
|
||||
fs::write(&self.file_path, contents).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The persisted registration data.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the file could not be read, or if the data in the
|
||||
/// file could not be deserialized.
|
||||
async fn read_registration_data(
|
||||
&self,
|
||||
) -> Result<Option<FrozenRegistrationData>, OAuthRegistrationStoreError> {
|
||||
let contents = match fs::read(&self.file_path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(error) => {
|
||||
if error.kind() == ErrorKind::NotFound {
|
||||
// The file doesn't exist so there is no data.
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Forward the error.
|
||||
return Err(error.into());
|
||||
}
|
||||
};
|
||||
|
||||
let registration_data: FrozenRegistrationData =
|
||||
serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?;
|
||||
|
||||
if registration_data.metadata.json().get() != self.metadata.json().get() {
|
||||
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(registration_data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use matrix_sdk_test::async_test;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use super::*;
|
||||
use crate::authentication::oauth::registration::{ApplicationType, Localized, OAuthGrantType};
|
||||
|
||||
#[async_test]
|
||||
async fn test_oauth_registration_store() {
|
||||
// Given a fresh registration store with a single static registration.
|
||||
let dir = tempdir().unwrap();
|
||||
let registrations_file = dir.path().join("oauth").join("registrations.json");
|
||||
|
||||
let static_url = Url::parse("https://example.com").unwrap();
|
||||
let static_id = ClientId::new("static_client_id".to_owned());
|
||||
let dynamic_url = Url::parse("https://example.org").unwrap();
|
||||
let dynamic_id = ClientId::new("dynamic_client_id".to_owned());
|
||||
|
||||
let mut static_registrations = HashMap::new();
|
||||
static_registrations.insert(static_url.clone(), static_id.clone());
|
||||
|
||||
let oidc_metadata = mock_metadata("Example".to_owned());
|
||||
|
||||
let registrations = OAuthRegistrationStore::new(registrations_file, oidc_metadata)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_static_registrations(static_registrations);
|
||||
|
||||
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone()));
|
||||
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None);
|
||||
|
||||
// When a dynamic registration is added.
|
||||
registrations
|
||||
.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Then the dynamic registration should be stored and the static registration
|
||||
// should be unaffected.
|
||||
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id));
|
||||
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_change_of_metadata() {
|
||||
// Given a single registration with an example app name.
|
||||
let dir = tempdir().unwrap();
|
||||
let registrations_file = dir.path().join("oidc").join("registrations.json");
|
||||
|
||||
let static_url = Url::parse("https://example.com").unwrap();
|
||||
let static_id = ClientId::new("static_client_id".to_owned());
|
||||
let dynamic_url = Url::parse("https://example.org").unwrap();
|
||||
let dynamic_id = ClientId::new("dynamic_client_id".to_owned());
|
||||
|
||||
let oidc_metadata = mock_metadata("Example".to_owned());
|
||||
|
||||
let mut static_registrations = HashMap::new();
|
||||
static_registrations.insert(static_url.clone(), static_id.clone());
|
||||
|
||||
let registrations = OAuthRegistrationStore::new(registrations_file.clone(), oidc_metadata)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_static_registrations(static_registrations.clone());
|
||||
registrations
|
||||
.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id.clone()));
|
||||
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), Some(dynamic_id));
|
||||
|
||||
// When the app name changes.
|
||||
let new_oidc_metadata = mock_metadata("New App".to_owned());
|
||||
|
||||
let registrations = OAuthRegistrationStore::new(registrations_file, new_oidc_metadata)
|
||||
.await
|
||||
.unwrap()
|
||||
.with_static_registrations(static_registrations);
|
||||
|
||||
// Then the dynamic registrations are cleared.
|
||||
assert_eq!(registrations.client_id(&dynamic_url).await.unwrap(), None);
|
||||
assert_eq!(registrations.client_id(&static_url).await.unwrap(), Some(static_id));
|
||||
}
|
||||
|
||||
fn mock_metadata(client_name: String) -> Raw<ClientMetadata> {
|
||||
let callback_url = Url::parse("https://example.org/login/callback").unwrap();
|
||||
let client_uri = Url::parse("https://example.org/").unwrap();
|
||||
|
||||
let mut metadata = ClientMetadata::new(
|
||||
ApplicationType::Web,
|
||||
vec![OAuthGrantType::AuthorizationCode { redirect_uris: vec![callback_url] }],
|
||||
Localized::new(client_uri, None),
|
||||
);
|
||||
metadata.client_name = Some(Localized::new(client_name, None));
|
||||
|
||||
Raw::new(&metadata).unwrap()
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,6 @@ use ruma::{
|
||||
owned_device_id, user_id, DeviceId, ServerName,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tempfile::tempdir;
|
||||
use tokio::sync::broadcast::error::TryRecvError;
|
||||
use url::Url;
|
||||
use wiremock::{
|
||||
@@ -18,14 +17,13 @@ use wiremock::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
registration_store::OAuthRegistrationStore, AuthorizationCode, AuthorizationError,
|
||||
AuthorizationResponse, OAuth, OAuthAuthorizationData, OAuthError, RedirectUriQueryParseError,
|
||||
UrlOrQuery,
|
||||
AuthorizationCode, AuthorizationError, AuthorizationResponse, OAuth, OAuthAuthorizationData,
|
||||
OAuthError, RedirectUriQueryParseError, UrlOrQuery,
|
||||
};
|
||||
use crate::{
|
||||
authentication::oauth::{
|
||||
error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError},
|
||||
AuthorizationValidationData, ClientRegistrationMethod, OAuthAuthorizationCodeError,
|
||||
AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError,
|
||||
},
|
||||
test_utils::{
|
||||
client::{
|
||||
@@ -40,7 +38,7 @@ use crate::{
|
||||
|
||||
const REDIRECT_URI_STRING: &str = "http://127.0.0.1:6778/oauth/callback";
|
||||
|
||||
async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAuthRegistrationStore)>
|
||||
async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, ClientRegistrationData)>
|
||||
{
|
||||
let server = MatrixMockServer::new().await;
|
||||
server.mock_who_am_i().ok().named("whoami").mount().await;
|
||||
@@ -53,12 +51,7 @@ async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, OAu
|
||||
let client = server.client_builder().unlogged().build().await;
|
||||
let client_metadata = mock_client_metadata();
|
||||
|
||||
let registrations_path =
|
||||
tempdir().unwrap().path().join("matrix-sdk-oauth").join("registrations.json");
|
||||
let registrations =
|
||||
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
|
||||
|
||||
Ok((client.oauth(), server, mock_redirect_uri(), registrations))
|
||||
Ok((client.oauth(), server, mock_redirect_uri(), client_metadata.into()))
|
||||
}
|
||||
|
||||
/// Check the URL in the given authorization data.
|
||||
@@ -154,35 +147,23 @@ async fn check_authorization_url(
|
||||
#[async_test]
|
||||
async fn test_high_level_login() -> anyhow::Result<()> {
|
||||
// Given a fresh environment.
|
||||
let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap();
|
||||
let registrations_path = registrations.file_path.clone();
|
||||
let client_metadata = registrations.metadata.clone();
|
||||
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
|
||||
|
||||
assert!(oauth.issuer().is_none());
|
||||
assert!(oauth.client_id().is_none());
|
||||
|
||||
// When getting the OIDC login URL.
|
||||
let authorization_data = oauth
|
||||
.login(registrations.into(), redirect_uri.clone(), None)
|
||||
.login(redirect_uri.clone(), None, Some(registration_data))
|
||||
.prompt(vec![Prompt::Create])
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Then the client should be configured correctly.
|
||||
assert_let!(Some(issuer) = oauth.issuer());
|
||||
assert!(oauth.issuer().is_some());
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// The client ID should have been saved in the registration file.
|
||||
let registrations =
|
||||
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
|
||||
assert_eq!(
|
||||
registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()),
|
||||
Some("test_client_id")
|
||||
);
|
||||
|
||||
check_authorization_url(&authorization_data, &oauth, issuer, None, Some("create"), None).await;
|
||||
|
||||
// When completing the login with a valid callback.
|
||||
redirect_uri.set_query(Some(&format!("code=42&state={}", authorization_data.state.secret())));
|
||||
|
||||
@@ -195,24 +176,14 @@ async fn test_high_level_login() -> anyhow::Result<()> {
|
||||
#[async_test]
|
||||
async fn test_high_level_login_cancellation() -> anyhow::Result<()> {
|
||||
// Given a client ready to complete login.
|
||||
let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap();
|
||||
let registrations_path = registrations.file_path.clone();
|
||||
let client_metadata = registrations.metadata.clone();
|
||||
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
|
||||
|
||||
let authorization_data =
|
||||
oauth.login(registrations.into(), redirect_uri.clone(), None).build().await.unwrap();
|
||||
oauth.login(redirect_uri.clone(), None, Some(registration_data)).build().await.unwrap();
|
||||
|
||||
assert_let!(Some(issuer) = oauth.issuer());
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// The client ID should have been saved in the registration file.
|
||||
let registrations =
|
||||
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
|
||||
assert_eq!(
|
||||
registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()),
|
||||
Some("test_client_id")
|
||||
);
|
||||
|
||||
check_authorization_url(&authorization_data, &oauth, issuer, None, None, None).await;
|
||||
|
||||
// When completing login with a cancellation callback.
|
||||
@@ -235,24 +206,14 @@ async fn test_high_level_login_cancellation() -> anyhow::Result<()> {
|
||||
#[async_test]
|
||||
async fn test_high_level_login_invalid_state() -> anyhow::Result<()> {
|
||||
// Given a client ready to complete login.
|
||||
let (oauth, _server, mut redirect_uri, registrations) = mock_environment().await.unwrap();
|
||||
let registrations_path = registrations.file_path.clone();
|
||||
let client_metadata = registrations.metadata.clone();
|
||||
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
|
||||
|
||||
let authorization_data =
|
||||
oauth.login(registrations.into(), redirect_uri.clone(), None).build().await.unwrap();
|
||||
oauth.login(redirect_uri.clone(), None, Some(registration_data)).build().await.unwrap();
|
||||
|
||||
assert_let!(Some(issuer) = oauth.issuer());
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// The client ID should have been saved in the registration file.
|
||||
let registrations =
|
||||
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
|
||||
assert_eq!(
|
||||
registrations.client_id(issuer).await.unwrap().as_ref().map(|id| id.as_str()),
|
||||
Some("test_client_id")
|
||||
);
|
||||
|
||||
check_authorization_url(&authorization_data, &oauth, issuer, None, None, None).await;
|
||||
|
||||
// When completing login with an old/tampered state.
|
||||
@@ -286,16 +247,14 @@ async fn test_login_url() -> anyhow::Result<()> {
|
||||
let redirect_uri = Url::parse(redirect_uri_str)?;
|
||||
|
||||
// No extra parameters.
|
||||
let authorization_data = oauth
|
||||
.login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone()))
|
||||
.build()
|
||||
.await?;
|
||||
let authorization_data =
|
||||
oauth.login(redirect_uri.clone(), Some(device_id.clone()), None).build().await?;
|
||||
check_authorization_url(&authorization_data, &oauth, &issuer, Some(&device_id), None, None)
|
||||
.await;
|
||||
|
||||
// With prompt parameter.
|
||||
let authorization_data = oauth
|
||||
.login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone()))
|
||||
.login(redirect_uri.clone(), Some(device_id.clone()), None)
|
||||
.prompt(vec![Prompt::Create])
|
||||
.build()
|
||||
.await?;
|
||||
@@ -311,7 +270,7 @@ async fn test_login_url() -> anyhow::Result<()> {
|
||||
|
||||
// With user_id_hint parameter.
|
||||
let authorization_data = oauth
|
||||
.login(ClientRegistrationMethod::None, redirect_uri.clone(), Some(device_id.clone()))
|
||||
.login(redirect_uri.clone(), Some(device_id.clone()), None)
|
||||
.user_id_hint(user_id!("@joe:example.org"))
|
||||
.build()
|
||||
.await?;
|
||||
@@ -719,7 +678,7 @@ async fn test_server_metadata() {
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_client_registration_methods() {
|
||||
async fn test_client_registration_data() {
|
||||
let server = MatrixMockServer::new().await;
|
||||
let oauth_server = server.oauth();
|
||||
let server_metadata = oauth_server.server_metadata();
|
||||
@@ -727,29 +686,26 @@ async fn test_client_registration_methods() {
|
||||
// Without registration we get an error.
|
||||
let client = server.client_builder().unlogged().build().await;
|
||||
let oauth = client.oauth();
|
||||
let res =
|
||||
oauth.use_registration_method(&server_metadata, &ClientRegistrationMethod::None).await;
|
||||
let res = oauth.use_registration_data(&server_metadata, None).await;
|
||||
assert_matches!(res, Err(OAuthError::NotRegistered));
|
||||
assert_eq!(oauth.client_id(), None);
|
||||
|
||||
// With a client ID.
|
||||
oauth
|
||||
.use_registration_method(
|
||||
&server_metadata,
|
||||
&ClientRegistrationMethod::ClientId(mock_client_id()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// With a static registration.
|
||||
let registration_data = ClientRegistrationData {
|
||||
metadata: mock_client_metadata(),
|
||||
static_registrations: Some([(server_metadata.issuer.clone(), mock_client_id())].into()),
|
||||
};
|
||||
oauth.use_registration_data(&server_metadata, Some(®istration_data)).await.unwrap();
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// If we call it again we get the same client ID.
|
||||
oauth
|
||||
.use_registration_method(
|
||||
&server_metadata,
|
||||
&ClientRegistrationMethod::ClientId(ClientId::new("other_client_id".to_owned())),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// If we call it again, it's a noop.
|
||||
let registration_data = ClientRegistrationData {
|
||||
metadata: mock_client_metadata(),
|
||||
static_registrations: Some(
|
||||
[(server_metadata.issuer.clone(), ClientId::new("other_client_id".to_owned()))].into(),
|
||||
),
|
||||
};
|
||||
oauth.use_registration_data(&server_metadata, Some(®istration_data)).await.unwrap();
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// With metadata we register a new client ID.
|
||||
@@ -765,40 +721,6 @@ async fn test_client_registration_methods() {
|
||||
.mount()
|
||||
.await;
|
||||
|
||||
oauth
|
||||
.use_registration_method(
|
||||
&server_metadata,
|
||||
&ClientRegistrationMethod::Metadata(client_metadata.clone()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// With registration store we register a new client ID.
|
||||
let client = server.client_builder().unlogged().build().await;
|
||||
let oauth = client.oauth();
|
||||
|
||||
let registrations_path =
|
||||
tempdir().unwrap().path().join("matrix-sdk-oauth").join("registrations.json");
|
||||
let registrations =
|
||||
OAuthRegistrationStore::new(registrations_path, client_metadata).await.unwrap();
|
||||
let store_method = ClientRegistrationMethod::Store(registrations);
|
||||
|
||||
oauth_server
|
||||
.mock_registration()
|
||||
.ok()
|
||||
.mock_once()
|
||||
.named("registration_with_store")
|
||||
.mount()
|
||||
.await;
|
||||
|
||||
oauth.use_registration_method(&server_metadata, &store_method).await.unwrap();
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
|
||||
// With same registration store, the client ID is loaded from the store.
|
||||
let client = server.client_builder().unlogged().build().await;
|
||||
let oauth = client.oauth();
|
||||
|
||||
oauth.use_registration_method(&server_metadata, &store_method).await.unwrap();
|
||||
oauth.use_registration_data(&server_metadata, Some(&client_metadata.into())).await.unwrap();
|
||||
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ use matrix_sdk::{
|
||||
authentication::oauth::{
|
||||
error::OAuthClientRegistrationError,
|
||||
registration::{ApplicationType, ClientMetadata, Localized, OAuthGrantType},
|
||||
AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError,
|
||||
OAuthRegistrationStore, OAuthSession, UrlOrQuery, UserSession,
|
||||
AccountManagementActionFull, ClientId, OAuthAuthorizationData, OAuthError, OAuthSession,
|
||||
UrlOrQuery, UserSession,
|
||||
},
|
||||
config::SyncSettings,
|
||||
encryption::{recovery::RecoveryState, CrossSigningResetAuthType},
|
||||
@@ -139,7 +139,7 @@ impl OAuthCli {
|
||||
let (client, client_session) = build_client(data_dir).await?;
|
||||
let cli = Self { client, restored: false, session_file };
|
||||
|
||||
if let Err(error) = cli.register_and_login(data_dir).await {
|
||||
if let Err(error) = cli.register_and_login().await {
|
||||
if error.downcast_ref::<OAuthError>().is_some_and(|error| {
|
||||
matches!(
|
||||
error,
|
||||
@@ -179,28 +179,18 @@ impl OAuthCli {
|
||||
|
||||
/// Register the client and log in the user via the OAuth 2.0 Authorization
|
||||
/// Code flow.
|
||||
async fn register_and_login(&self, data_dir: &Path) -> anyhow::Result<()> {
|
||||
async fn register_and_login(&self) -> anyhow::Result<()> {
|
||||
let oauth = self.client.oauth();
|
||||
|
||||
// We create a loop here so the user can retry if an error happens.
|
||||
loop {
|
||||
// The registrations store allows to store the client ID separately
|
||||
// and reuse it between sessions, so we don't need to
|
||||
// re-register each time.
|
||||
// We put it in the parent directory because we remove the data directory on
|
||||
// logout.
|
||||
let registration_file =
|
||||
data_dir.parent().expect("directory has a parent").join("oauth_registrations");
|
||||
let registration_store =
|
||||
OAuthRegistrationStore::new(registration_file, client_metadata()).await?;
|
||||
|
||||
// Here we spawn a server to listen on the loopback interface. Another option
|
||||
// would be to register a custom URI scheme with the system and handle
|
||||
// the redirect when the custom URI scheme is opened.
|
||||
let (redirect_uri, server_handle) = LocalServerBuilder::new().spawn().await?;
|
||||
|
||||
let OAuthAuthorizationData { url, .. } =
|
||||
oauth.login(registration_store.into(), redirect_uri, None).build().await?;
|
||||
oauth.login(redirect_uri, None, Some(client_metadata().into())).build().await?;
|
||||
|
||||
let query_string =
|
||||
use_auth_url(&url, server_handle).await.map(|query| query.0).unwrap_or_default();
|
||||
|
||||
@@ -112,10 +112,10 @@ async fn login(proxy: Option<Url>) -> Result<()> {
|
||||
|
||||
let client = client.build().await?;
|
||||
|
||||
let metadata = client_metadata();
|
||||
let registration_data = client_metadata().into();
|
||||
let oauth = client.oauth();
|
||||
|
||||
let login_client = oauth.login_with_qr_code(&data, metadata.into());
|
||||
let login_client = oauth.login_with_qr_code(&data, Some(®istration_data));
|
||||
let mut subscriber = login_client.subscribe_to_progress();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
|
||||
Reference in New Issue
Block a user