refactor(oauth): Use tokio::fs APIs instead of spawn_blocking

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-03-21 09:53:53 +01:00
committed by Ivan Enderlin
parent ecdc68aa1c
commit 22cbce82ce

View File

@@ -20,18 +20,12 @@
//! within the same app, and those accounts need to share the same OAuth 2.0
//! client registration.
use std::{
collections::HashMap,
fs,
fs::File,
io::{BufReader, BufWriter, ErrorKind},
path::PathBuf,
};
use std::{collections::HashMap, io::ErrorKind, path::PathBuf};
use oauth2::ClientId;
use ruma::serde::Raw;
use serde::{Deserialize, Serialize};
use tokio::task::spawn_blocking;
use tokio::fs;
use url::Url;
use super::ClientMetadata;
@@ -107,14 +101,10 @@ impl OAuthRegistrationStore {
file: PathBuf,
metadata: Raw<ClientMetadata>,
) -> Result<Self, OAuthRegistrationStoreError> {
spawn_blocking(move || {
let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
fs::create_dir_all(parent)?;
let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
fs::create_dir_all(parent).await?;
Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
})
.await
.expect("Task join error")
Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
}
/// Add static registrations to the store.
@@ -186,14 +176,10 @@ impl OAuthRegistrationStore {
});
data.dynamic_registrations.insert(issuer, client_id);
let file_path = self.file_path.clone();
let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?;
fs::write(&self.file_path, contents).await?;
spawn_blocking(move || {
let writer = BufWriter::new(File::create(&file_path)?);
serde_json::to_writer(writer, &data).map_err(OAuthRegistrationStoreError::IntoJson)
})
.await
.expect("Task join error")
Ok(())
}
/// The persisted registration data.
@@ -205,40 +191,28 @@ impl OAuthRegistrationStore {
async fn read_registration_data(
&self,
) -> Result<Option<FrozenRegistrationData>, OAuthRegistrationStoreError> {
let file_path = self.file_path.clone();
let registration_data = spawn_blocking(move || {
let file = match File::open(&file_path) {
Ok(file) => file,
Err(error) => {
if error.kind() == ErrorKind::NotFound {
// The file doesn't exist so there is no data.
return Ok::<_, OAuthRegistrationStoreError>(None);
}
// Forward the error.
return Err(error.into());
let contents = match fs::read(&self.file_path).await {
Ok(contents) => contents,
Err(error) => {
if error.kind() == ErrorKind::NotFound {
// The file doesn't exist so there is no data.
return Ok(None);
}
};
let registration_data: FrozenRegistrationData =
serde_json::from_reader(BufReader::new(file))
.map_err(OAuthRegistrationStoreError::FromJson)?;
Ok(Some(registration_data))
})
.await
.expect("Task join error")?;
Ok(registration_data.filter(|registration_data| {
let changed = registration_data.metadata.json().get() != self.metadata.json().get();
if changed {
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
// Forward the error.
return Err(error.into());
}
};
!changed
}))
let registration_data: FrozenRegistrationData =
serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?;
if registration_data.metadata.json().get() != self.metadata.json().get() {
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
Ok(None)
} else {
Ok(Some(registration_data))
}
}
}