refactor(oauth): Don't ignore errors when reading the file of the OAuthRegistrationStore

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-03-19 11:54:14 +01:00
committed by Ivan Enderlin
parent fbd4a7dc38
commit 0b0f84b784
2 changed files with 72 additions and 36 deletions

View File

@@ -502,7 +502,10 @@ impl OAuth {
return Ok(());
};
if self.load_client_registration(issuer, &registrations) {
if self
.load_client_registration(issuer, &registrations)
.map_err(OAuthClientRegistrationError::from)?
{
tracing::info!("OAuth 2.0 configuration loaded from disk.");
return Ok(());
}
@@ -539,14 +542,14 @@ impl OAuth {
&self,
issuer: Url,
registrations: &OAuthRegistrationStore,
) -> bool {
let Some(client_id) = registrations.client_id(&issuer) else {
return false;
) -> Result<bool, OAuthRegistrationStoreError> {
let Some(client_id) = registrations.client_id(&issuer)? else {
return Ok(false);
};
self.restore_registered_client(issuer, client_id);
true
Ok(true)
}
/// The OAuth 2.0 authorization server used for authorization.

View File

@@ -24,7 +24,7 @@ use std::{
collections::HashMap,
fs,
fs::File,
io::{BufReader, BufWriter},
io::{BufReader, BufWriter, ErrorKind},
path::{Path, PathBuf},
};
@@ -47,6 +47,9 @@ pub enum OAuthRegistrationStoreError {
/// An error occurred when serializing the registration data.
#[error("failed to serialize registration data: {0}")]
IntoJson(serde_json::Error),
/// An error occurred when deserializing the registration data.
#[error("failed to deserialize registration data: {0}")]
FromJson(serde_json::Error),
}
/// An API to store and restore OAuth 2.0 client registrations.
@@ -114,25 +117,45 @@ impl OAuthRegistrationStore {
/// Returns the client ID registered for a particular issuer or `None` if a
/// registration hasn't been made.
pub fn client_id(&self, issuer: &Url) -> Option<ClientId> {
///
/// # Arguments
///
/// * `issuer` - The issuer to look up.
///
/// # Errors
///
/// Returns an error if the file could not be read, or if the data in the
/// file could not be deserialized.
pub fn client_id(&self, issuer: &Url) -> Result<Option<ClientId>, OAuthRegistrationStoreError> {
if let Some(client_id) = self.static_registrations.get(issuer) {
return Some(client_id.clone());
return Ok(Some(client_id.clone()));
}
let mut data = self.read_registration_data()?;
data.dynamic_registrations.remove(issuer)
let data = self.read_registration_data()?;
Ok(data.and_then(|mut data| data.dynamic_registrations.remove(issuer)))
}
/// Stores a new client ID registration for a particular issuer.
///
/// If a client ID has already been stored for the given issuer, this will
/// overwrite the old value.
///
/// # Arguments
///
/// * `client_id` - The client ID obtained after registration.
///
/// * `issuer` - The issuer associated with the client ID.
///
/// # Errors
///
/// Returns an error if the file could not be read from or written to, or if
/// the data in the file could not be (de)serialized.
pub fn set_and_write_client_id(
&self,
client_id: ClientId,
issuer: Url,
) -> Result<(), OAuthRegistrationStoreError> {
let mut data = self.read_registration_data().unwrap_or_else(|| {
let mut data = self.read_registration_data()?.unwrap_or_else(|| {
tracing::info!("Generating new OAuth 2.0 client registration data");
FrozenRegistrationData {
metadata: self.metadata.clone(),
@@ -145,28 +168,38 @@ impl OAuthRegistrationStore {
serde_json::to_writer(writer, &data).map_err(OAuthRegistrationStoreError::IntoJson)
}
/// Returns the persisted registration data.
fn read_registration_data(&self) -> Option<FrozenRegistrationData> {
let reader = BufReader::new(
File::open(&self.file_path)
.map_err(|error| {
tracing::warn!("Failed to load registrations file: {error}");
})
.ok()?,
);
/// The persisted registration data.
///
/// # Errors
///
/// Returns an error if the file could not be read, or if the data in the
/// file could not be deserialized.
fn read_registration_data(
&self,
) -> Result<Option<FrozenRegistrationData>, OAuthRegistrationStoreError> {
let file = match File::open(&self.file_path) {
Ok(file) => file,
Err(error) => {
if error.kind() == ErrorKind::NotFound {
// The file doesn't exist so there is no data.
return Ok(None);
}
let registration_data: FrozenRegistrationData = serde_json::from_reader(reader)
.map_err(|error| {
tracing::warn!("Failed to deserialize registrations file: {error}");
})
.ok()?;
// Forward the error.
return Err(error.into());
}
};
let registration_data: FrozenRegistrationData =
serde_json::from_reader(BufReader::new(file))
.map_err(OAuthRegistrationStoreError::FromJson)?;
if registration_data.metadata.json().get() != self.metadata.json().get() {
tracing::warn!("Metadata mismatch, ignoring any stored registrations.");
return None;
tracing::info!("Metadata mismatch, ignoring any stored registrations.");
return Ok(None);
}
Some(registration_data)
Ok(Some(registration_data))
}
}
@@ -197,16 +230,16 @@ mod tests {
OAuthRegistrationStore::new(&registrations_file, oidc_metadata, static_registrations)
.unwrap();
assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
assert_eq!(registrations.client_id(&dynamic_url), None);
assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id.clone()));
assert_eq!(registrations.client_id(&dynamic_url).unwrap(), None);
// When a dynamic registration is added.
registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();
// Then the dynamic registration should be stored and the static registration
// should be unaffected.
assert_eq!(registrations.client_id(&static_url), Some(static_id));
assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));
assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id));
assert_eq!(registrations.client_id(&dynamic_url).unwrap(), Some(dynamic_id));
}
#[test]
@@ -233,8 +266,8 @@ mod tests {
.unwrap();
registrations.set_and_write_client_id(dynamic_id.clone(), dynamic_url.clone()).unwrap();
assert_eq!(registrations.client_id(&static_url), Some(static_id.clone()));
assert_eq!(registrations.client_id(&dynamic_url), Some(dynamic_id));
assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id.clone()));
assert_eq!(registrations.client_id(&dynamic_url).unwrap(), Some(dynamic_id));
// When the app name changes.
let new_oidc_metadata = mock_metadata("New App".to_owned());
@@ -247,8 +280,8 @@ mod tests {
.unwrap();
// Then the dynamic registrations are cleared.
assert_eq!(registrations.client_id(&dynamic_url), None);
assert_eq!(registrations.client_id(&static_url), Some(static_id));
assert_eq!(registrations.client_id(&dynamic_url).unwrap(), None);
assert_eq!(registrations.client_id(&static_url).unwrap(), Some(static_id));
}
fn mock_metadata(client_name: String) -> Raw<ClientMetadata> {