mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-03 21:45:51 -04:00
refactor(oauth): Use tokio::fs APIs instead of spawn_blocking
Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
committed by
Ivan Enderlin
parent
ecdc68aa1c
commit
22cbce82ce
@@ -20,18 +20,12 @@
|
||||
//! within the same app, and those accounts need to share the same OAuth 2.0
|
||||
//! client registration.
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
fs::File,
|
||||
io::{BufReader, BufWriter, ErrorKind},
|
||||
path::PathBuf,
|
||||
};
|
||||
use std::{collections::HashMap, io::ErrorKind, path::PathBuf};
|
||||
|
||||
use oauth2::ClientId;
|
||||
use ruma::serde::Raw;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::task::spawn_blocking;
|
||||
use tokio::fs;
|
||||
use url::Url;
|
||||
|
||||
use super::ClientMetadata;
|
||||
@@ -107,14 +101,10 @@ impl OAuthRegistrationStore {
|
||||
file: PathBuf,
|
||||
metadata: Raw<ClientMetadata>,
|
||||
) -> Result<Self, OAuthRegistrationStoreError> {
|
||||
spawn_blocking(move || {
|
||||
let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
|
||||
fs::create_dir_all(parent)?;
|
||||
let parent = file.parent().ok_or(OAuthRegistrationStoreError::NotAFilePath)?;
|
||||
fs::create_dir_all(parent).await?;
|
||||
|
||||
Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
|
||||
})
|
||||
.await
|
||||
.expect("Task join error")
|
||||
Ok(OAuthRegistrationStore { file_path: file, metadata, static_registrations: None })
|
||||
}
|
||||
|
||||
/// Add static registrations to the store.
|
||||
@@ -186,14 +176,10 @@ impl OAuthRegistrationStore {
|
||||
});
|
||||
data.dynamic_registrations.insert(issuer, client_id);
|
||||
|
||||
let file_path = self.file_path.clone();
|
||||
let contents = serde_json::to_vec(&data).map_err(OAuthRegistrationStoreError::IntoJson)?;
|
||||
fs::write(&self.file_path, contents).await?;
|
||||
|
||||
spawn_blocking(move || {
|
||||
let writer = BufWriter::new(File::create(&file_path)?);
|
||||
serde_json::to_writer(writer, &data).map_err(OAuthRegistrationStoreError::IntoJson)
|
||||
})
|
||||
.await
|
||||
.expect("Task join error")
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The persisted registration data.
|
||||
@@ -205,40 +191,28 @@ impl OAuthRegistrationStore {
|
||||
async fn read_registration_data(
|
||||
&self,
|
||||
) -> Result<Option<FrozenRegistrationData>, OAuthRegistrationStoreError> {
|
||||
let file_path = self.file_path.clone();
|
||||
|
||||
let registration_data = spawn_blocking(move || {
|
||||
let file = match File::open(&file_path) {
|
||||
Ok(file) => file,
|
||||
Err(error) => {
|
||||
if error.kind() == ErrorKind::NotFound {
|
||||
// The file doesn't exist so there is no data.
|
||||
return Ok::<_, OAuthRegistrationStoreError>(None);
|
||||
}
|
||||
|
||||
// Forward the error.
|
||||
return Err(error.into());
|
||||
let contents = match fs::read(&self.file_path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(error) => {
|
||||
if error.kind() == ErrorKind::NotFound {
|
||||
// The file doesn't exist so there is no data.
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let registration_data: FrozenRegistrationData =
|
||||
serde_json::from_reader(BufReader::new(file))
|
||||
.map_err(OAuthRegistrationStoreError::FromJson)?;
|
||||
|
||||
Ok(Some(registration_data))
|
||||
})
|
||||
.await
|
||||
.expect("Task join error")?;
|
||||
|
||||
Ok(registration_data.filter(|registration_data| {
|
||||
let changed = registration_data.metadata.json().get() != self.metadata.json().get();
|
||||
|
||||
if changed {
|
||||
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
|
||||
// Forward the error.
|
||||
return Err(error.into());
|
||||
}
|
||||
};
|
||||
|
||||
!changed
|
||||
}))
|
||||
let registration_data: FrozenRegistrationData =
|
||||
serde_json::from_slice(&contents).map_err(OAuthRegistrationStoreError::FromJson)?;
|
||||
|
||||
if registration_data.metadata.json().get() != self.metadata.json().get() {
|
||||
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(registration_data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user