deduplicating handler: use a new QueryState data structure

This commit is contained in:
Benjamin Bouvier
2024-06-11 16:57:25 +02:00
parent ac0992953e
commit ad63d28cfa

View File

@@ -24,7 +24,20 @@ use tokio::sync::Mutex;
use crate::{Error, Result};
type DeduplicatedRequestMap<Key> = Mutex<BTreeMap<Key, Arc<Mutex<Option<Result<(), ()>>>>>>;
/// State machine for the state of a query deduplicated by the
/// [`DeduplicatingHandler`].
enum QueryState {
/// The query hasn't completed yet. This doesn't mean it hasn't *started*
/// yet, but rather that it couldn't get to completion: some
/// intermediate steps might have run.
NotFinishedYet,
/// The query has completed with an `Ok` result.
Success,
/// The query has completed with an `Err` result.
Failure,
}
type DeduplicatedRequestMap<Key> = Mutex<BTreeMap<Key, Arc<Mutex<QueryState>>>>;
/// Handler that properly deduplicates function calls given a key uniquely
/// identifying the call kind, and will properly report error upwards in case
@@ -62,45 +75,51 @@ impl<Key: Clone + Ord + std::hash::Hash> DeduplicatingHandler<Key> {
let mut request_guard = request_mutex.lock().await;
return match *request_guard {
Some(Ok(())) => {
QueryState::Success => {
// The query completed with a success: forward this success.
Ok(())
}
Some(Err(())) => {
QueryState::Failure => {
// The query completed with an error, but we don't know what it is; report
// there was an error.
Err(Error::ConcurrentRequestFailed)
}
None => {
// The query hasn't completed, it could have been cancelled. Repeat it.
self.run_code(key, code, &mut *request_guard).await
QueryState::NotFinishedYet => {
// If we could take a hold onto the mutex without it being in the success or
// failure state, then the query hasn't completed (e.g. it could have been
// cancelled). Repeat it.
//
// Note: there might be other waiters for the deduplicated result; they will
// still be waiting for the mutex above, since the mutex is obtained for at
// most one holder at the same time.
self.run_code(key, code, &mut request_guard).await
}
};
}
// Start at the `None` state to indicate we haven't completed the request yet.
let request_mutex = Arc::new(Mutex::new(None));
let request_mutex = Arc::new(Mutex::new(QueryState::NotFinishedYet));
map.insert(key.clone(), request_mutex.clone());
let mut request_guard = request_mutex.lock().await;
drop(map);
self.run_code(key, code, &mut *request_guard).await
self.run_code(key, code, &mut request_guard).await
}
async fn run_code<'a, F: Future<Output = Result<()>> + SendOutsideWasm + 'a>(
&self,
key: Key,
code: F,
result: &mut Option<Result<(), ()>>,
result: &mut QueryState,
) -> Result<()> {
match code.await {
Ok(()) => {
// Mark the request as completed.
*result = Some(Ok(()));
*result = QueryState::Success;
self.inflight.lock().await.remove(&key);
@@ -109,7 +128,7 @@ impl<Key: Clone + Ord + std::hash::Hash> DeduplicatingHandler<Key> {
Err(err) => {
// Propagate the error state to other callers.
*result = Some(Err(()));
*result = QueryState::Failure;
// Remove the request from the in-flights set.
self.inflight.lock().await.remove(&key);