feat: have the SyncService own the EncryptionSyncPermit and add try_get_encryption_sync_permit

This commit is contained in:
Benjamin Bouvier
2023-09-04 16:05:37 +02:00
parent 9469b77741
commit de9b6d25cd
2 changed files with 31 additions and 6 deletions

View File

@@ -33,7 +33,7 @@ use thiserror::Error;
use tokio::{
sync::{
mpsc::{Receiver, Sender},
Mutex as AsyncMutex,
Mutex as AsyncMutex, OwnedMutexGuard,
},
task::{spawn, JoinHandle},
};
@@ -88,6 +88,13 @@ pub struct SyncService {
/// Task running the encryption sync.
encryption_sync_task: Arc<Mutex<Option<JoinHandle<()>>>>,
/// Global lock to allow using at most one `EncryptionSync` at all times.
///
/// This ensures that there's only one ever existing in the application's
/// lifetime (under the assumption that there is at most one
/// `SyncService` per application).
encryption_sync_permit: Arc<AsyncMutex<EncryptionSyncPermit>>,
/// Scheduler task ensuring proper termination.
///
/// This task is waiting for a `TerminationReport` from any of the other two
@@ -208,11 +215,9 @@ impl SyncService {
&self,
encryption_sync: Arc<EncryptionSync>,
sender: Sender<TerminationReport>,
sync_permit_guard: OwnedMutexGuard<EncryptionSyncPermit>,
) -> impl Future<Output = ()> {
async move {
let sync_permit = Arc::new(AsyncMutex::new(EncryptionSyncPermit::new()));
let sync_permit_guard = sync_permit.lock_owned().await;
let encryption_sync_stream = encryption_sync.sync(sync_permit_guard);
pin_mut!(encryption_sync_stream);
@@ -324,8 +329,12 @@ impl SyncService {
// Then, take care of the encryption sync.
if let Some(encryption_sync) = self.encryption_sync.clone() {
*self.encryption_sync_task.lock().unwrap() =
Some(spawn(self.spawn_encryption_sync(encryption_sync, sender.clone())));
let sync_permit_guard = self.encryption_sync_permit.clone().lock_owned().await;
*self.encryption_sync_task.lock().unwrap() = Some(spawn(self.spawn_encryption_sync(
encryption_sync,
sender.clone(),
sync_permit_guard,
)));
}
// Spawn the scheduler task.
@@ -389,6 +398,14 @@ impl SyncService {
Ok(())
}
/// Attempt to get a permit to use an `EncryptionSync` at a given time.
///
/// This ensures there is at most one [`EncryptionSync`] active at any time,
/// per application.
pub fn try_get_encryption_sync_permit(&self) -> Option<OwnedMutexGuard<EncryptionSyncPermit>> {
self.encryption_sync_permit.clone().try_lock_owned().ok()
}
}
enum TerminationOrigin {
@@ -476,6 +493,8 @@ impl SyncServiceBuilder {
/// the background. The resulting `SyncService` must be kept alive as
/// long as the sliding syncs are supposed to run.
pub async fn build(self) -> Result<SyncService, Error> {
let encryption_sync_permit = Arc::new(AsyncMutex::new(EncryptionSyncPermit::new()));
let (room_list, encryption_sync) = if self.with_encryption_sync {
let room_list = RoomListService::new(self.client.clone()).await?;
@@ -501,6 +520,7 @@ impl SyncServiceBuilder {
scheduler_sender: Mutex::new(None),
state: SharedObservable::new(State::Idle),
modifying_state: AsyncMutex::new(()),
encryption_sync_permit,
})
}
}

View File

@@ -77,16 +77,19 @@ async fn test_sync_service_state() -> anyhow::Result<()> {
assert_eq!(state_stream.get(), State::Idle);
assert!(server.received_requests().await.unwrap().is_empty());
assert_eq!(sync_service.task_states(), (false, false));
assert!(sync_service.try_get_encryption_sync_permit().is_some());
// After starting, the sync service is, well, running.
sync_service.start().await;
assert_next_matches!(state_stream, State::Running);
assert_eq!(sync_service.task_states(), (true, true));
assert!(sync_service.try_get_encryption_sync_permit().is_none());
// Restarting while started doesn't change the current state.
sync_service.start().await;
assert_pending!(state_stream);
assert_eq!(sync_service.task_states(), (true, true));
assert!(sync_service.try_get_encryption_sync_permit().is_none());
// Let the server respond a few times.
tokio::time::sleep(Duration::from_millis(300)).await;
@@ -95,6 +98,7 @@ async fn test_sync_service_state() -> anyhow::Result<()> {
sync_service.stop().await?;
assert_next_matches!(state_stream, State::Idle);
assert_eq!(sync_service.task_states(), (false, false));
assert!(sync_service.try_get_encryption_sync_permit().is_some());
let mut num_encryption_sync_requests: i32 = 0;
let mut num_room_list_requests = 0;
@@ -149,6 +153,7 @@ async fn test_sync_service_state() -> anyhow::Result<()> {
sync_service.start().await;
assert_next_matches!(state_stream, State::Running);
assert_eq!(sync_service.task_states(), (true, true));
assert!(sync_service.try_get_encryption_sync_permit().is_none());
tokio::time::sleep(Duration::from_millis(100)).await;