From 22cbce82ce94632ee4a0aacdfd83cfb83eeb0e16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Commaille?= Date: Fri, 21 Mar 2025 09:53:53 +0100 Subject: [PATCH] refactor(oauth): Use tokio::fs APIs instead of spawn_blocking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Commaille --- .../oauth/registration_store.rs | 78 +++++++------------ 1 file changed, 26 insertions(+), 52 deletions(-) diff --git a/crates/matrix-sdk/src/authentication/oauth/registration_store.rs b/crates/matrix-sdk/src/authentication/oauth/registration_store.rs index 16da0a125..0b1cb6a52 100644 --- a/crates/matrix-sdk/src/authentication/oauth/registration_store.rs +++ b/crates/matrix-sdk/src/authentication/oauth/registration_store.rs @@ -20,18 +20,12 @@ //! within the same app, and those accounts need to share the same OAuth 2.0 //! client registration. -use std::{ - collections::HashMap, - fs, - fs::File, - io::{BufReader, BufWriter, ErrorKind}, - path::PathBuf, -}; +use std::{collections::HashMap, io::ErrorKind, path::PathBuf}; use oauth2::ClientId; use ruma::serde::Raw; use serde::{Deserialize, Serialize}; -use tokio::task::spawn_blocking; +use tokio::fs; use url::Url; use super::ClientMetadata; @@ -107,14 +101,10 @@ impl OAuthRegistrationStore { file: PathBuf, metadata: Raw, ) -> Result { - spawn_blocking(move || { - let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?; - fs::create_dir_all(parent)?; + let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?; + fs::create_dir_all(parent).await?; - Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None }) - }) - .await - .expect("Task join error") + Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None }) } /// Add static registrations to the store. @@ -186,14 +176,10 @@ impl OAuthRegistrationStore { }); data.dynamic_registrations.insert(issuer, client_id); - let file_path = self.file_path.clone(); + let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?; + fs::write(&self.file_path, contents).await?; - spawn_blocking(move || { - let writer = BufWriter::new(File::create(&file_path)?); - serde_json::to_writer(writer, &data).map_err(OAuthRegistrationStoreError::IntoJson) - }) - .await - .expect("Task join error") + Ok(()) } /// The persisted registration data. @@ -205,40 +191,28 @@ impl OAuthRegistrationStore { async fn read_registration_data( &self, ) -> Result, OAuthRegistrationStoreError> { - let file_path = self.file_path.clone(); - - let registration_data = spawn_blocking(move || { - let file = match File::open(&file_path) { - Ok(file) => file, - Err(error) => { - if error.kind() == ErrorKind::NotFound { - // The file doesn't exist so there is no data. - return Ok::<_, OAuthRegistrationStoreError>(None); - } - - // Forward the error. - return Err(error.into()); + let contents = match fs::read(&self.file_path).await { + Ok(contents) => contents, + Err(error) => { + if error.kind() == ErrorKind::NotFound { + // The file doesn't exist so there is no data. + return Ok(None); } - }; - let registration_data: FrozenRegistrationData = - serde_json::from_reader(BufReader::new(file)) - .map_err(OAuthRegistrationStoreError::FromJson)?; - - Ok(Some(registration_data)) - }) - .await - .expect("Task join error")?; - - Ok(registration_data.filter(|registration_data| { - let changed = registration_data.metadata.json().get() != self.metadata.json().get(); - - if changed { - tracing::info!("Metadata mismatch, ignoring any stored registrations."); + // Forward the error. + return Err(error.into()); } + }; - !changed - })) + let registration_data: FrozenRegistrationData = + serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?; + + if registration_data.metadata.json().get() != self.metadata.json().get() { + tracing::info!("Metadata mismatch, ignoring any stored registrations."); + Ok(None) + } else { + Ok(Some(registration_data)) + } } }