feat(base): Allow to clone MediaService

We want to be able to send it to a new task, so the easiest way is to be able to clone it.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-02-02 12:03:12 +01:00
parent 4e1ae3d5e9
commit 8dc2ec9dc4
2 changed files with 41 additions and 28 deletions

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt;
use std::{fmt, sync::Arc};
use async_trait::async_trait;
use matrix_sdk_common::{locks::Mutex, AsyncTraitDeps};
@@ -28,6 +28,11 @@ use crate::{event_cache::store::EventCacheStoreError, media::MediaRequestParamet
/// [`EventCacheStore`]: crate::event_cache::store::EventCacheStore
#[derive(Debug)]
pub struct MediaService<Time: TimeProvider = DefaultTimeProvider> {
inner: Arc<MediaServiceInner<Time>>,
}
#[derive(Debug)]
struct MediaServiceInner<Time: TimeProvider = DefaultTimeProvider> {
/// The time provider.
time_provider: Time,
@@ -61,11 +66,13 @@ where
/// Construct a new `MediaService` with the given `TimeProvider` and an
/// empty `MediaRetentionPolicy`.
fn with_time_provider(time_provider: Time) -> Self {
Self {
let inner = MediaServiceInner {
time_provider,
policy: Mutex::new(MediaRetentionPolicy::empty()),
cleanup_guard: AsyncMutex::new(()),
}
};
Self { inner: Arc::new(inner) }
}
/// Restore the previous state of the [`MediaRetentionPolicy`] from data
@@ -78,10 +85,15 @@ where
/// * `policy` - The `MediaRetentionPolicy` that was persisted in the store.
pub fn restore(&self, policy: Option<MediaRetentionPolicy>) {
if let Some(policy) = policy {
*self.policy.lock() = policy;
*self.inner.policy.lock() = policy;
}
}
/// Get the current time from the inner [`TimeProvider`].
fn now(&self) -> SystemTime {
self.inner.time_provider.now()
}
/// Set the `MediaRetentionPolicy` of this service.
///
/// # Arguments
@@ -96,14 +108,14 @@ where
) -> Result<(), Store::Error> {
store.set_media_retention_policy_inner(policy).await?;
*self.policy.lock() = policy;
*self.inner.policy.lock() = policy;
Ok(())
}
/// Get the `MediaRetentionPolicy` of this service.
pub fn media_retention_policy(&self) -> MediaRetentionPolicy {
*self.policy.lock()
*self.inner.policy.lock()
}
/// Add a media file's content in the media store.
@@ -134,15 +146,7 @@ where
return Ok(());
}
store
.add_media_content_inner(
request,
content,
self.time_provider.now(),
policy,
ignore_policy,
)
.await
store.add_media_content_inner(request, content, self.now(), policy, ignore_policy).await
}
/// Set whether the current [`MediaRetentionPolicy`] should be ignored for
@@ -179,7 +183,7 @@ where
store: &Store,
request: &MediaRequestParameters,
) -> Result<Option<Vec<u8>>, Store::Error> {
store.get_media_content_inner(request, self.time_provider.now()).await
store.get_media_content_inner(request, self.now()).await
}
/// Get a media file's content associated to an `MxcUri` from the
@@ -195,7 +199,7 @@ where
store: &Store,
uri: &MxcUri,
) -> Result<Option<Vec<u8>>, Store::Error> {
store.get_media_content_for_uri_inner(uri, self.time_provider.now()).await
store.get_media_content_for_uri_inner(uri, self.now()).await
}
/// Clean up the media cache with the current `MediaRetentionPolicy`.
@@ -209,7 +213,7 @@ where
&self,
store: &Store,
) -> Result<(), Store::Error> {
let Ok(_guard) = self.cleanup_guard.try_lock() else {
let Ok(_guard) = self.inner.cleanup_guard.try_lock() else {
// There is another ongoing cleanup.
return Ok(());
};
@@ -221,7 +225,16 @@ where
return Ok(());
}
store.clean_up_media_cache_inner(policy, self.time_provider.now()).await
store.clean_up_media_cache_inner(policy, self.now()).await
}
}
impl<Time> Clone for MediaService<Time>
where
Time: TimeProvider,
{
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}
@@ -683,7 +696,7 @@ mod tests {
assert_eq!(media_content.last_access, now);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
// Get media from request.
@@ -696,7 +709,7 @@ mod tests {
assert_eq!(media.last_access, now);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
// Get media from URI.
@@ -786,7 +799,7 @@ mod tests {
assert_eq!(media_content.last_access, now);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
// Get media from request.
@@ -799,7 +812,7 @@ mod tests {
assert_eq!(media.last_access, now);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
// Get media from URI.
@@ -812,7 +825,7 @@ mod tests {
assert_eq!(media.last_access, now);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
// Add big media, it will not work because it is bigger than the max file size.
@@ -865,7 +878,7 @@ mod tests {
assert_eq!(media.last_access, now);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
// Get media from URI.
@@ -881,7 +894,7 @@ mod tests {
assert_eq!(store.last_media_cleanup_time_inner().await.unwrap(), None);
let now = now + Duration::from_secs(60);
service.time_provider.set_now(now);
service.inner.time_provider.set_now(now);
store.reset_accessed();
service.clean_up_media_cache(&store).await.unwrap();

View File

@@ -78,7 +78,7 @@ const CHUNK_TYPE_GAP_TYPE_STRING: &str = "G";
pub struct SqliteEventCacheStore {
store_cipher: Option<Arc<StoreCipher>>,
pool: SqlitePool,
media_service: Arc<MediaService>,
media_service: MediaService,
}
#[cfg(not(tarpaulin_include))]
@@ -119,7 +119,7 @@ impl SqliteEventCacheStore {
let media_retention_policy = conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await?;
media_service.restore(media_retention_policy);
Ok(Self { store_cipher, pool, media_service: Arc::new(media_service) })
Ok(Self { store_cipher, pool, media_service })
}
fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>> {