mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-19 06:04:31 -04:00
sdk: add a way to reset the in-memory and on-disk caches for server capabilities
This required changing the OnceCell to RwLock<Option<>>, because a OnceCell can't be reset to another value.
This commit is contained in:
@@ -442,6 +442,15 @@ impl Client {
|
||||
let http_client = self.inner.http_client();
|
||||
Ok(http_client.get(url).send().await?.text().await?)
|
||||
}
|
||||
|
||||
/// Empty the server version and unstable features cache.
|
||||
///
|
||||
/// Since the SDK caches server capabilities (versions and unstable
|
||||
/// features), it's possible to have a stale entry in the cache. This
|
||||
/// functions makes it possible to force reset it.
|
||||
pub async fn reset_server_capabilities(&self) -> Result<(), ClientError> {
|
||||
Ok(self.inner.reset_server_capabilities().await?)
|
||||
}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
|
||||
@@ -239,10 +239,10 @@ pub(crate) struct ClientInner {
|
||||
base_client: BaseClient,
|
||||
|
||||
/// The Matrix versions the server supports (well-known ones only)
|
||||
server_versions: OnceCell<Box<[MatrixVersion]>>,
|
||||
server_versions: RwLock<Option<Box<[MatrixVersion]>>>,
|
||||
|
||||
/// The unstable features and their on/off state on the server
|
||||
unstable_features: OnceCell<BTreeMap<String, bool>>,
|
||||
unstable_features: RwLock<Option<BTreeMap<String, bool>>>,
|
||||
|
||||
/// Collection of locks individual client methods might want to use, either
|
||||
/// to ensure that only a single call to a method happens at once or to
|
||||
@@ -324,8 +324,8 @@ impl ClientInner {
|
||||
http_client,
|
||||
base_client,
|
||||
locks: Default::default(),
|
||||
server_versions: OnceCell::new_with(server_versions),
|
||||
unstable_features: OnceCell::new_with(unstable_features),
|
||||
server_versions: RwLock::new(server_versions),
|
||||
unstable_features: RwLock::new(unstable_features),
|
||||
typing_notice_times: Default::default(),
|
||||
event_handlers: Default::default(),
|
||||
notification_handlers: Default::default(),
|
||||
@@ -1430,7 +1430,7 @@ impl Client {
|
||||
config,
|
||||
homeserver,
|
||||
access_token.as_deref(),
|
||||
self.server_versions().await?,
|
||||
&self.server_versions().await?,
|
||||
send_progress,
|
||||
)
|
||||
.await
|
||||
@@ -1509,20 +1509,25 @@ impl Client {
|
||||
self.fetch_and_cache_server_capabilities().await
|
||||
}
|
||||
|
||||
pub(crate) async fn server_versions(&self) -> HttpResult<&[MatrixVersion]> {
|
||||
Ok(self
|
||||
.inner
|
||||
.server_versions
|
||||
.get_or_try_init(|| async {
|
||||
let (versions, unstable_features) =
|
||||
self.load_or_fetch_server_capabilities().await?;
|
||||
pub(crate) async fn server_versions(&self) -> HttpResult<Box<[MatrixVersion]>> {
|
||||
if let Some(server_versions) = self.inner.server_versions.read().await.as_ref() {
|
||||
return Ok(server_versions.clone());
|
||||
}
|
||||
|
||||
// Fill the unstable features, while we're at it.
|
||||
self.inner.unstable_features.get_or_init(|| async { unstable_features }).await;
|
||||
let mut guard = self.inner.server_versions.write().await;
|
||||
if let Some(server_versions) = guard.as_ref() {
|
||||
return Ok(server_versions.clone());
|
||||
}
|
||||
|
||||
HttpResult::Ok(versions)
|
||||
})
|
||||
.await?)
|
||||
let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?;
|
||||
|
||||
*guard = Some(versions.clone());
|
||||
drop(guard);
|
||||
|
||||
// Fill the unstable features, while we're at it.
|
||||
*self.inner.unstable_features.write().await = Some(unstable_features);
|
||||
|
||||
Ok(versions)
|
||||
}
|
||||
|
||||
/// Get unstable features from by fetching from the server or the cache.
|
||||
@@ -1539,20 +1544,39 @@ impl Client {
|
||||
/// let msc_x = unstable_features.get("msc_x").unwrap_or(&false);
|
||||
/// # anyhow::Ok(()) };
|
||||
/// ```
|
||||
pub async fn unstable_features(&self) -> HttpResult<&BTreeMap<String, bool>> {
|
||||
Ok(self
|
||||
.inner
|
||||
.unstable_features
|
||||
.get_or_try_init(|| async {
|
||||
let (versions, unstable_features) =
|
||||
self.load_or_fetch_server_capabilities().await?;
|
||||
pub async fn unstable_features(&self) -> HttpResult<BTreeMap<String, bool>> {
|
||||
if let Some(unstable_features) = self.inner.unstable_features.read().await.as_ref() {
|
||||
return Ok(unstable_features.clone());
|
||||
}
|
||||
|
||||
// Fill the versions, while we're at it.
|
||||
self.inner.server_versions.get_or_init(|| async { versions }).await;
|
||||
let mut guard = self.inner.unstable_features.write().await;
|
||||
if let Some(unstable_features) = guard.as_ref() {
|
||||
return Ok(unstable_features.clone());
|
||||
}
|
||||
|
||||
HttpResult::Ok(unstable_features)
|
||||
})
|
||||
.await?)
|
||||
let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?;
|
||||
|
||||
*guard = Some(unstable_features.clone());
|
||||
drop(guard);
|
||||
|
||||
// Fill the versions, while we're at it.
|
||||
*self.inner.server_versions.write().await = Some(versions);
|
||||
|
||||
Ok(unstable_features)
|
||||
}
|
||||
|
||||
/// Empty the server version and unstable features cache.
|
||||
///
|
||||
/// Since the SDK caches server capabilities (versions and unstable
|
||||
/// features), it's possible to have a stale entry in the cache. This
|
||||
/// functions makes it possible to force reset it.
|
||||
pub async fn reset_server_capabilities(&self) -> Result<()> {
|
||||
// Empty the in-memory caches.
|
||||
*self.inner.server_versions.write().await = None;
|
||||
*self.inner.unstable_features.write().await = None;
|
||||
|
||||
// Empty the store cache.
|
||||
Ok(self.store().remove_kv_data(StateStoreDataKey::ServerCapabilities).await?)
|
||||
}
|
||||
|
||||
/// Check whether MSC 4028 is enabled on the homeserver.
|
||||
@@ -2161,8 +2185,8 @@ impl Client {
|
||||
sliding_sync_proxy,
|
||||
self.inner.http_client.clone(),
|
||||
self.inner.base_client.clone_with_in_memory_state_store(),
|
||||
self.inner.server_versions.get().cloned(),
|
||||
self.inner.unstable_features.get().cloned(),
|
||||
self.inner.server_versions.read().await.clone(),
|
||||
self.inner.unstable_features.read().await.clone(),
|
||||
self.inner.respect_login_well_known,
|
||||
self.inner.event_cache.clone(),
|
||||
self.inner.send_queue_data.clone(),
|
||||
@@ -2236,7 +2260,10 @@ pub(crate) mod tests {
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use matrix_sdk_base::RoomState;
|
||||
use matrix_sdk_base::{
|
||||
store::{MemoryStore, StoreConfig},
|
||||
RoomState,
|
||||
};
|
||||
use matrix_sdk_test::{
|
||||
async_test, test_json, JoinedRoomBuilder, StateTestEvent, SyncResponseBuilder,
|
||||
DEFAULT_TEST_ROOM_ID,
|
||||
@@ -2245,8 +2272,8 @@ pub(crate) mod tests {
|
||||
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
|
||||
|
||||
use ruma::{
|
||||
events::ignored_user_list::IgnoredUserListEventContent, owned_room_id, room_id, RoomId,
|
||||
UserId,
|
||||
api::MatrixVersion, events::ignored_user_list::IgnoredUserListEventContent, owned_room_id,
|
||||
room_id, RoomId, ServerName, UserId,
|
||||
};
|
||||
use url::Url;
|
||||
use wiremock::{
|
||||
@@ -2609,4 +2636,88 @@ pub(crate) mod tests {
|
||||
Arc::strong_count(&client.unwrap().inner)
|
||||
);
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn test_server_capabilities_caching() {
|
||||
let server = MockServer::start().await;
|
||||
let server_url = server.uri();
|
||||
let domain = server_url.strip_prefix("http://").unwrap();
|
||||
let server_name = <&ServerName>::try_from(domain).unwrap();
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/.well-known/matrix/client"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_raw(
|
||||
test_json::WELL_KNOWN.to_string().replace("HOMESERVER_URL", server_url.as_ref()),
|
||||
"application/json",
|
||||
))
|
||||
.named("well known mock")
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let versions_mock = Mock::given(method("GET"))
|
||||
.and(path("/_matrix/client/versions"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
|
||||
.named("first versions mock")
|
||||
.expect(1)
|
||||
.mount_as_scoped(&server)
|
||||
.await;
|
||||
|
||||
let memory_store = Arc::new(MemoryStore::new());
|
||||
let client = Client::builder()
|
||||
.insecure_server_name_no_tls(server_name)
|
||||
.store_config(StoreConfig::new().state_store(memory_store.clone()))
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(client.server_versions().await.unwrap().len(), 1);
|
||||
|
||||
// This second call hits the in-memory cache.
|
||||
assert!(client
|
||||
.server_versions()
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|version| *version == MatrixVersion::V1_0));
|
||||
|
||||
drop(client);
|
||||
|
||||
let client = Client::builder()
|
||||
.insecure_server_name_no_tls(server_name)
|
||||
.store_config(StoreConfig::new().state_store(memory_store.clone()))
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// This third call hits the on-disk cache.
|
||||
assert_eq!(
|
||||
client.unstable_features().await.unwrap().get("org.matrix.e2e_cross_signing"),
|
||||
Some(&true)
|
||||
);
|
||||
|
||||
drop(versions_mock);
|
||||
server.verify().await;
|
||||
|
||||
// Now, reset the cache, and observe the endpoint being called again once.
|
||||
client.reset_server_capabilities().await.unwrap();
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/_matrix/client/versions"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
|
||||
.expect(1)
|
||||
.named("second versions mock")
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
// Hits network again.
|
||||
assert_eq!(client.server_versions().await.unwrap().len(), 1);
|
||||
// Hits in-memory cache again.
|
||||
assert!(client
|
||||
.server_versions()
|
||||
.await
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|version| *version == MatrixVersion::V1_0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,13 +130,13 @@ impl MatrixAuth {
|
||||
.try_into_http_request::<Vec<u8>>(
|
||||
homeserver.as_str(),
|
||||
SendAccessToken::None,
|
||||
server_versions,
|
||||
&server_versions,
|
||||
)
|
||||
} else {
|
||||
sso_login::v3::Request::new(redirect_url.to_owned()).try_into_http_request::<Vec<u8>>(
|
||||
homeserver.as_str(),
|
||||
SendAccessToken::None,
|
||||
server_versions,
|
||||
&server_versions,
|
||||
)
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user