sdk: add a way to reset the in-memory and on-disk caches for server capabilities

This required changing the OnceCell to RwLock<Option<>>, because a
OnceCell can't be reset to another value.
This commit is contained in:
Benjamin Bouvier
2024-07-04 19:21:00 +02:00
parent a76b6faa9f
commit 19fcae4e0b
3 changed files with 156 additions and 36 deletions

View File

@@ -442,6 +442,15 @@ impl Client {
let http_client = self.inner.http_client();
Ok(http_client.get(url).send().await?.text().await?)
}
/// Empty the server version and unstable features cache.
///
/// Since the SDK caches server capabilities (versions and unstable
/// features), it's possible to have a stale entry in the cache. This
/// functions makes it possible to force reset it.
pub async fn reset_server_capabilities(&self) -> Result<(), ClientError> {
Ok(self.inner.reset_server_capabilities().await?)
}
}
impl Client {

View File

@@ -239,10 +239,10 @@ pub(crate) struct ClientInner {
base_client: BaseClient,
/// The Matrix versions the server supports (well-known ones only)
server_versions: OnceCell<Box<[MatrixVersion]>>,
server_versions: RwLock<Option<Box<[MatrixVersion]>>>,
/// The unstable features and their on/off state on the server
unstable_features: OnceCell<BTreeMap<String, bool>>,
unstable_features: RwLock<Option<BTreeMap<String, bool>>>,
/// Collection of locks individual client methods might want to use, either
/// to ensure that only a single call to a method happens at once or to
@@ -324,8 +324,8 @@ impl ClientInner {
http_client,
base_client,
locks: Default::default(),
server_versions: OnceCell::new_with(server_versions),
unstable_features: OnceCell::new_with(unstable_features),
server_versions: RwLock::new(server_versions),
unstable_features: RwLock::new(unstable_features),
typing_notice_times: Default::default(),
event_handlers: Default::default(),
notification_handlers: Default::default(),
@@ -1430,7 +1430,7 @@ impl Client {
config,
homeserver,
access_token.as_deref(),
self.server_versions().await?,
&self.server_versions().await?,
send_progress,
)
.await
@@ -1509,20 +1509,25 @@ impl Client {
self.fetch_and_cache_server_capabilities().await
}
pub(crate) async fn server_versions(&self) -> HttpResult<&[MatrixVersion]> {
Ok(self
.inner
.server_versions
.get_or_try_init(|| async {
let (versions, unstable_features) =
self.load_or_fetch_server_capabilities().await?;
pub(crate) async fn server_versions(&self) -> HttpResult<Box<[MatrixVersion]>> {
if let Some(server_versions) = self.inner.server_versions.read().await.as_ref() {
return Ok(server_versions.clone());
}
// Fill the unstable features, while we're at it.
self.inner.unstable_features.get_or_init(|| async { unstable_features }).await;
let mut guard = self.inner.server_versions.write().await;
if let Some(server_versions) = guard.as_ref() {
return Ok(server_versions.clone());
}
HttpResult::Ok(versions)
})
.await?)
let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?;
*guard = Some(versions.clone());
drop(guard);
// Fill the unstable features, while we're at it.
*self.inner.unstable_features.write().await = Some(unstable_features);
Ok(versions)
}
/// Get unstable features from by fetching from the server or the cache.
@@ -1539,20 +1544,39 @@ impl Client {
/// let msc_x = unstable_features.get("msc_x").unwrap_or(&false);
/// # anyhow::Ok(()) };
/// ```
pub async fn unstable_features(&self) -> HttpResult<&BTreeMap<String, bool>> {
Ok(self
.inner
.unstable_features
.get_or_try_init(|| async {
let (versions, unstable_features) =
self.load_or_fetch_server_capabilities().await?;
pub async fn unstable_features(&self) -> HttpResult<BTreeMap<String, bool>> {
if let Some(unstable_features) = self.inner.unstable_features.read().await.as_ref() {
return Ok(unstable_features.clone());
}
// Fill the versions, while we're at it.
self.inner.server_versions.get_or_init(|| async { versions }).await;
let mut guard = self.inner.unstable_features.write().await;
if let Some(unstable_features) = guard.as_ref() {
return Ok(unstable_features.clone());
}
HttpResult::Ok(unstable_features)
})
.await?)
let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?;
*guard = Some(unstable_features.clone());
drop(guard);
// Fill the versions, while we're at it.
*self.inner.server_versions.write().await = Some(versions);
Ok(unstable_features)
}
/// Empty the server version and unstable features cache.
///
/// Since the SDK caches server capabilities (versions and unstable
/// features), it's possible to have a stale entry in the cache. This
/// functions makes it possible to force reset it.
pub async fn reset_server_capabilities(&self) -> Result<()> {
// Empty the in-memory caches.
*self.inner.server_versions.write().await = None;
*self.inner.unstable_features.write().await = None;
// Empty the store cache.
Ok(self.store().remove_kv_data(StateStoreDataKey::ServerCapabilities).await?)
}
/// Check whether MSC 4028 is enabled on the homeserver.
@@ -2161,8 +2185,8 @@ impl Client {
sliding_sync_proxy,
self.inner.http_client.clone(),
self.inner.base_client.clone_with_in_memory_state_store(),
self.inner.server_versions.get().cloned(),
self.inner.unstable_features.get().cloned(),
self.inner.server_versions.read().await.clone(),
self.inner.unstable_features.read().await.clone(),
self.inner.respect_login_well_known,
self.inner.event_cache.clone(),
self.inner.send_queue_data.clone(),
@@ -2236,7 +2260,10 @@ pub(crate) mod tests {
use std::{sync::Arc, time::Duration};
use assert_matches::assert_matches;
use matrix_sdk_base::RoomState;
use matrix_sdk_base::{
store::{MemoryStore, StoreConfig},
RoomState,
};
use matrix_sdk_test::{
async_test, test_json, JoinedRoomBuilder, StateTestEvent, SyncResponseBuilder,
DEFAULT_TEST_ROOM_ID,
@@ -2245,8 +2272,8 @@ pub(crate) mod tests {
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
use ruma::{
events::ignored_user_list::IgnoredUserListEventContent, owned_room_id, room_id, RoomId,
UserId,
api::MatrixVersion, events::ignored_user_list::IgnoredUserListEventContent, owned_room_id,
room_id, RoomId, ServerName, UserId,
};
use url::Url;
use wiremock::{
@@ -2609,4 +2636,88 @@ pub(crate) mod tests {
Arc::strong_count(&client.unwrap().inner)
);
}
#[async_test]
async fn test_server_capabilities_caching() {
let server = MockServer::start().await;
let server_url = server.uri();
let domain = server_url.strip_prefix("http://").unwrap();
let server_name = <&ServerName>::try_from(domain).unwrap();
Mock::given(method("GET"))
.and(path("/.well-known/matrix/client"))
.respond_with(ResponseTemplate::new(200).set_body_raw(
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
"application/json",
))
.named("well known mock")
.expect(2)
.mount(&server)
.await;
let versions_mock = Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
.named("first versions mock")
.expect(1)
.mount_as_scoped(&server)
.await;
let memory_store = Arc::new(MemoryStore::new());
let client = Client::builder()
.insecure_server_name_no_tls(server_name)
.store_config(StoreConfig::new().state_store(memory_store.clone()))
.build()
.await
.unwrap();
assert_eq!(client.server_versions().await.unwrap().len(), 1);
// This second call hits the in-memory cache.
assert!(client
.server_versions()
.await
.unwrap()
.iter()
.any(|version| *version == MatrixVersion::V1_0));
drop(client);
let client = Client::builder()
.insecure_server_name_no_tls(server_name)
.store_config(StoreConfig::new().state_store(memory_store.clone()))
.build()
.await
.unwrap();
// This third call hits the on-disk cache.
assert_eq!(
client.unstable_features().await.unwrap().get("org.matrix.e2e_cross_signing"),
Some(&true)
);
drop(versions_mock);
server.verify().await;
// Now, reset the cache, and observe the endpoint being called again once.
client.reset_server_capabilities().await.unwrap();
Mock::given(method("GET"))
.and(path("/_matrix/client/versions"))
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
.expect(1)
.named("second versions mock")
.mount(&server)
.await;
// Hits network again.
assert_eq!(client.server_versions().await.unwrap().len(), 1);
// Hits in-memory cache again.
assert!(client
.server_versions()
.await
.unwrap()
.iter()
.any(|version| *version == MatrixVersion::V1_0));
}
}

View File

@@ -130,13 +130,13 @@ impl MatrixAuth {
.try_into_http_request::<Vec<u8>>(
homeserver.as_str(),
SendAccessToken::None,
server_versions,
&server_versions,
)
} else {
sso_login::v3::Request::new(redirect_url.to_owned()).try_into_http_request::<Vec<u8>>(
homeserver.as_str(),
SendAccessToken::None,
server_versions,
&server_versions,
)
};