From 0b0f84b784b7fd0bcecdcd7c388f190c6e2159c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Wed, 19 Mar 2025 11:54:14 +0100 Subject: [PATCH] refactor(oauth): Don't ignore errors when reading the file of the OAuthRegistrationStore MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Commaille --- .../src/authentication/oauth/mod.rs | 13 ++- .../oauth/registration_store.rs | 95 +++++++++++++------ 2 files changed, 72 insertions(+), 36 deletions(-) diff --git a/crates/matrix-sdk/src/authentication/oauth/mod.rs b/crates/matrix-sdk/src/authentication/oauth/mod.rs index 5a9eed48b..c1ecc5028 100644 --- a/crates/matrix-sdk/src/authentication/oauth/mod.rs +++ b/crates/matrix-sdk/src/authentication/oauth/mod.rs @@ -502,7 +502,10 @@ impl OAuth { return Ok(()); }; - if self.load_client_registration(issuer, ®istrations) { + if self + .load_client_registration(issuer, ®istrations) + .map_err(OAuthClientRegistrationError::from)? + { tracing::info!("OAuth 2.0 configuration loaded from disk."); return Ok(()); } @@ -539,14 +542,14 @@ impl OAuth { &self, issuer: Url, registrations: &OAuthRegistrationStore, - ) -> bool { - let Some(client_id) = registrations.client_id(&issuer) else { - return false; + ) -> Result { + let Some(client_id) = registrations.client_id(&issuer)? else { + return Ok(false); }; self.restore_registered_client(issuer, client_id); - true + Ok(true) } /// The OAuth 2.0 authorization server used for authorization. diff --git a/crates/matrix-sdk/src/authentication/oauth/registration_store.rs b/crates/matrix-sdk/src/authentication/oauth/registration_store.rs index 59d30db71..b895a9635 100644 --- a/crates/matrix-sdk/src/authentication/oauth/registration_store.rs +++ b/crates/matrix-sdk/src/authentication/oauth/registration_store.rs @@ -24,7 +24,7 @@ use std::{ collections::HashMap, fs, fs::File, - io::{BufReader, BufWriter}, + io::{BufReader, BufWriter, ErrorKind}, path::{Path, PathBuf}, }; @@ -47,6 +47,9 @@ pub enum OAuthRegistrationStoreError { /// An error occurred when serializing the registration data. #[error("failed to serialize registration data: {0}")] IntoJson(serde_json::Error), + /// An error occurred when deserializing the registration data. + #[error("failed to deserialize registration data: {0}")] + FromJson(serde_json::Error), } /// An API to store and restore OAuth 2.0 client registrations. @@ -114,25 +117,45 @@ impl OAuthRegistrationStore { /// Returns the client ID registered for a particular issuer or `None` if a /// registration hasn't been made. - pub fn client_id(&self, issuer: &Url) -> Option { + /// + /// # Arguments + /// + /// * `issuer` - The issuer to look up. + /// + /// # Errors + /// + /// Returns an error if the file could not be read, or if the data in the + /// file could not be deserialized. + pub fn client_id(&self, issuer: &Url) -> Result, OAuthRegistrationStoreError> { if let Some(client_id) = self.static_registrations.get(issuer) { - return Some(client_id.clone()); + return Ok(Some(client_id.clone())); } - let mut data = self.read_registration_data()?; - data.dynamic_registrations.remove(issuer) + let data = self.read_registration_data()?; + Ok(data.and_then(|mut data| data.dynamic_registrations.remove(issuer))) } /// Stores a new client ID registration for a particular issuer. /// /// If a client ID has already been stored for the given issuer, this will /// overwrite the old value. + /// + /// # Arguments + /// + /// * `client_id` - The client ID obtained after registration. + /// + /// * `issuer` - The issuer associated with the client ID. + /// + /// # Errors + /// + /// Returns an error if the file could not be read from or written to, or if + /// the data in the file could not be (de)serialized. pub fn set_and_write_client_id( &self, client_id: ClientId, issuer: Url, ) -> Result<(), OAuthRegistrationStoreError> { - let mut data = self.read_registration_data().unwrap_or_else(|| { + let mut data = self.read_registration_data()?.unwrap_or_else(|| { tracing::info!("Generating new OAuth 2.0 client registration data"); FrozenRegistrationData { metadata: self.metadata.clone(), @@ -145,28 +168,38 @@ impl OAuthRegistrationStore { serde_json::to_writer(writer, &data).map_err(OAuthRegistrationStoreError::IntoJson) } - /// Returns the persisted registration data. - fn read_registration_data(&self) -> Option { - let reader = BufReader::new( - File::open(&self.file_path) - .map_err(|error| { - tracing::warn!("Failed to load registrations file: {error}"); - }) - .ok()?, - ); + /// The persisted registration data. + /// + /// # Errors + /// + /// Returns an error if the file could not be read, or if the data in the + /// file could not be deserialized. + fn read_registration_data( + &self, + ) -> Result, OAuthRegistrationStoreError> { + let file = match File::open(&self.file_path) { + Ok(file) => file, + Err(error) => { + if error.kind() == ErrorKind::NotFound { + // The file doesn't exist so there is no data. + return Ok(None); + } - let registration_data: FrozenRegistrationData = serde_json::from_reader(reader) - .map_err(|error| { - tracing::warn!("Failed to deserialize registrations file: {error}"); - }) - .ok()?; + // Forward the error. + return Err(error.into()); + } + }; + + let registration_data: FrozenRegistrationData = + serde_json::from_reader(BufReader::new(file)) + .map_err(OAuthRegistrationStoreError::FromJson)?; if registration_data.metadata.json().get() != self.metadata.json().get() { - tracing::warn!("Metadata mismatch, ignoring any stored registrations."); - return None; + tracing::info!("Metadata mismatch, ignoring any stored registrations."); + return Ok(None); } - Some(registration_data) + Ok(Some(registration_data)) } } @@ -197,16 +230,16 @@ mod tests { OAuthRegistrationStore::new(®istrations_file, oidc_metadata, static_registrations) .unwrap(); - assert_eq!(registrations.client_id(&static_url), Some(static_id.clone())); - assert_eq!(registrations.client_id(&dynamic_url), None); + assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id.clone())); + assert_eq!(registrations.client_id(&dynamic_url).unwrap(), None); // When a dynamic registration is added. registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap(); // Then the dynamic registration should be stored and the static registration // should be unaffected. - assert_eq!(registrations.client_id(&static_url), Some(static_id)); - assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id)); + assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id)); + assert_eq!(registrations.client_id(&dynamic_url).unwrap(), Some(dynamic_id)); } #[test] @@ -233,8 +266,8 @@ mod tests { .unwrap(); registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap(); - assert_eq!(registrations.client_id(&static_url), Some(static_id.clone())); - assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id)); + assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id.clone())); + assert_eq!(registrations.client_id(&dynamic_url).unwrap(), Some(dynamic_id)); // When the app name changes. let new_oidc_metadata = mock_metadata("New App".to_owned()); @@ -247,8 +280,8 @@ mod tests { .unwrap(); // Then the dynamic registrations are cleared. - assert_eq!(registrations.client_id(&dynamic_url), None); - assert_eq!(registrations.client_id(&static_url), Some(static_id)); + assert_eq!(registrations.client_id(&dynamic_url).unwrap(), None); + assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id)); } fn mock_metadata(client_name: String) -> Raw {