From 1bc044349e54a680e36043f182d391d853d9640d Mon Sep 17 00:00:00 2001 From: Benjamin Bouvier Date: Mon, 8 Jul 2024 15:45:41 +0200 Subject: [PATCH] sdk: use a single field store both the server versions and the unstable features The only thing is that our client builder allows setting only the server versions, so this means the new field has to contain two `Option`al fields. This causes a bit of churn, but isn't too bad in the end. --- crates/matrix-sdk/src/client/builder.rs | 13 ++- crates/matrix-sdk/src/client/mod.rs | 126 ++++++++++++------------ 2 files changed, 72 insertions(+), 67 deletions(-) diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index 2d4100be1..61854d716 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -37,8 +37,9 @@ use crate::http_client::HttpSettings; #[cfg(feature = "experimental-oidc")] use crate::oidc::OidcCtx; use crate::{ - authentication::AuthCtx, config::RequestConfig, error::RumaApiError, http_client::HttpClient, - send_queue::SendQueueData, HttpError, IdParseError, + authentication::AuthCtx, client::ClientServerCapabilities, config::RequestConfig, + error::RumaApiError, http_client::HttpClient, send_queue::SendQueueData, HttpError, + IdParseError, }; /// Builder that allows creating and configuring various parts of a [`Client`]. @@ -477,6 +478,11 @@ impl ClientBuilder { // Enable the send queue by default. let send_queue = Arc::new(SendQueueData::new(true)); + let server_capabilities = ClientServerCapabilities { + server_versions: self.server_versions, + unstable_features: None, + }; + let event_cache = OnceCell::new(); let inner = ClientInner::new( auth_ctx, @@ -485,8 +491,7 @@ impl ClientBuilder { sliding_sync_proxy, http_client, base_client, - self.server_versions, - None, + server_capabilities, self.respect_login_well_known, event_cache, send_queue, diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 4788503c8..b4edd53ef 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -238,11 +238,9 @@ pub(crate) struct ClientInner { /// User session data. base_client: BaseClient, - /// The Matrix versions the server supports (well-known ones only) - server_versions: RwLock>>, - - /// The unstable features and their on/off state on the server - unstable_features: RwLock>>, + /// Server capabilities, either prefilled during building or fetched from + /// the server. + server_capabilities: RwLock, /// Collection of locks individual client methods might want to use, either /// to ensure that only a single call to a method happens at once or to @@ -309,8 +307,7 @@ impl ClientInner { #[cfg(feature = "experimental-sliding-sync")] sliding_sync_proxy: Option, http_client: HttpClient, base_client: BaseClient, - server_versions: Option>, - unstable_features: Option>, + server_capabilities: ClientServerCapabilities, respect_login_well_known: bool, event_cache: OnceCell, send_queue: Arc, @@ -324,8 +321,7 @@ impl ClientInner { http_client, base_client, locks: Default::default(), - server_versions: RwLock::new(server_versions), - unstable_features: RwLock::new(unstable_features), + server_capabilities: RwLock::new(server_capabilities), typing_notice_times: Default::default(), event_handlers: Default::default(), notification_handlers: Default::default(), @@ -1444,7 +1440,8 @@ impl Client { .send(SessionChange::UnknownToken { soft_logout: *soft_logout }); } - async fn fetch_and_cache_server_capabilities( + /// Fetches server capabilities from network; no caching. + async fn fetch_server_capabilities( &self, ) -> HttpResult<(Box<[MatrixVersion]>, BTreeMap)> { let resp = self @@ -1466,26 +1463,11 @@ impl Client { versions.push(MatrixVersion::V1_0); } - let unstable_features = resp.unstable_features; - - { - // Attempt to cache the result. - let encoded = ServerCapabilities::new(&versions, unstable_features.clone()); - if let Err(err) = self - .store() - .set_kv_data( - StateStoreDataKey::ServerCapabilities, - StateStoreDataValue::ServerCapabilities(encoded), - ) - .await - { - warn!("error when caching server capabilities: {err}"); - } - } - - Ok((versions.into(), unstable_features)) + Ok((versions.into(), resp.unstable_features)) } + /// Load server capabilities from storage, or fetch them from network and + /// cache them. async fn load_or_fetch_server_capabilities( &self, ) -> HttpResult<(Box<[MatrixVersion]>, BTreeMap)> { @@ -1506,28 +1488,54 @@ impl Client { } } - self.fetch_and_cache_server_capabilities().await - } + let (versions, unstable_features) = self.fetch_server_capabilities().await?; - pub(crate) async fn server_versions(&self) -> HttpResult> { - if let Some(server_versions) = self.inner.server_versions.read().await.as_ref() { - return Ok(server_versions.clone()); + // Attempt to cache the result in storage. + { + let encoded = ServerCapabilities::new(&versions, unstable_features.clone()); + if let Err(err) = self + .store() + .set_kv_data( + StateStoreDataKey::ServerCapabilities, + StateStoreDataValue::ServerCapabilities(encoded), + ) + .await + { + warn!("error when caching server capabilities: {err}"); + } } - let mut guard = self.inner.server_versions.write().await; - if let Some(server_versions) = guard.as_ref() { - return Ok(server_versions.clone()); + Ok((versions, unstable_features)) + } + + async fn get_or_load_and_cache_server_capabilities< + T, + F: Fn(&ClientServerCapabilities) -> Option, + >( + &self, + f: F, + ) -> HttpResult { + let caps = &self.inner.server_capabilities; + if let Some(val) = f(&*caps.read().await) { + return Ok(val); + } + + let mut guard = caps.write().await; + if let Some(val) = f(&guard) { + return Ok(val); } let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?; - *guard = Some(versions.clone()); - drop(guard); + guard.server_versions = Some(versions); + guard.unstable_features = Some(unstable_features); - // Fill the unstable features, while we're at it. - *self.inner.unstable_features.write().await = Some(unstable_features); + // SAFETY: both fields were set above, so the function will always return some. + Ok(f(&guard).unwrap()) + } - Ok(versions) + pub(crate) async fn server_versions(&self) -> HttpResult> { + self.get_or_load_and_cache_server_capabilities(|caps| caps.server_versions.clone()).await } /// Get unstable features from by fetching from the server or the cache. @@ -1545,24 +1553,7 @@ impl Client { /// # anyhow::Ok(()) }; /// ``` pub async fn unstable_features(&self) -> HttpResult> { - if let Some(unstable_features) = self.inner.unstable_features.read().await.as_ref() { - return Ok(unstable_features.clone()); - } - - let mut guard = self.inner.unstable_features.write().await; - if let Some(unstable_features) = guard.as_ref() { - return Ok(unstable_features.clone()); - } - - let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?; - - *guard = Some(unstable_features.clone()); - drop(guard); - - // Fill the versions, while we're at it. - *self.inner.server_versions.write().await = Some(versions); - - Ok(unstable_features) + self.get_or_load_and_cache_server_capabilities(|caps| caps.unstable_features.clone()).await } /// Empty the server version and unstable features cache. @@ -1572,8 +1563,9 @@ impl Client { /// functions makes it possible to force reset it. pub async fn reset_server_capabilities(&self) -> Result<()> { // Empty the in-memory caches. - *self.inner.server_versions.write().await = None; - *self.inner.unstable_features.write().await = None; + let mut guard = self.inner.server_capabilities.write().await; + guard.server_versions = None; + guard.unstable_features = None; // Empty the store cache. Ok(self.store().remove_kv_data(StateStoreDataKey::ServerCapabilities).await?) @@ -2185,8 +2177,7 @@ impl Client { sliding_sync_proxy, self.inner.http_client.clone(), self.inner.base_client.clone_with_in_memory_state_store(), - self.inner.server_versions.read().await.clone(), - self.inner.unstable_features.read().await.clone(), + self.inner.server_capabilities.read().await.clone(), self.inner.respect_login_well_known, self.inner.event_cache.clone(), self.inner.send_queue_data.clone(), @@ -2254,6 +2245,15 @@ impl WeakClient { } } +#[derive(Clone)] +struct ClientServerCapabilities { + /// The Matrix versions the server supports (well-known ones only). + server_versions: Option>, + + /// The unstable features and their on/off state on the server. + unstable_features: Option>, +} + // The http mocking library is not supported for wasm32 #[cfg(all(test, not(target_arch = "wasm32")))] pub(crate) mod tests {