diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index 2d4100be1..61854d716 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -37,8 +37,9 @@ use crate::http_client::HttpSettings; #[cfg(feature = "experimental-oidc")] use crate::oidc::OidcCtx; use crate::{ - authentication::AuthCtx, config::RequestConfig, error::RumaApiError, http_client::HttpClient, - send_queue::SendQueueData, HttpError, IdParseError, + authentication::AuthCtx, client::ClientServerCapabilities, config::RequestConfig, + error::RumaApiError, http_client::HttpClient, send_queue::SendQueueData, HttpError, + IdParseError, }; /// Builder that allows creating and configuring various parts of a [`Client`]. @@ -477,6 +478,11 @@ impl ClientBuilder { // Enable the send queue by default. let send_queue = Arc::new(SendQueueData::new(true)); + let server_capabilities = ClientServerCapabilities { + server_versions: self.server_versions, + unstable_features: None, + }; + let event_cache = OnceCell::new(); let inner = ClientInner::new( auth_ctx, @@ -485,8 +491,7 @@ impl ClientBuilder { sliding_sync_proxy, http_client, base_client, - self.server_versions, - None, + server_capabilities, self.respect_login_well_known, event_cache, send_queue, diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 4788503c8..b4edd53ef 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -238,11 +238,9 @@ pub(crate) struct ClientInner { /// User session data. base_client: BaseClient, - /// The Matrix versions the server supports (well-known ones only) - server_versions: RwLock>>, - - /// The unstable features and their on/off state on the server - unstable_features: RwLock>>, + /// Server capabilities, either prefilled during building or fetched from + /// the server. + server_capabilities: RwLock, /// Collection of locks individual client methods might want to use, either /// to ensure that only a single call to a method happens at once or to @@ -309,8 +307,7 @@ impl ClientInner { #[cfg(feature = "experimental-sliding-sync")] sliding_sync_proxy: Option, http_client: HttpClient, base_client: BaseClient, - server_versions: Option>, - unstable_features: Option>, + server_capabilities: ClientServerCapabilities, respect_login_well_known: bool, event_cache: OnceCell, send_queue: Arc, @@ -324,8 +321,7 @@ impl ClientInner { http_client, base_client, locks: Default::default(), - server_versions: RwLock::new(server_versions), - unstable_features: RwLock::new(unstable_features), + server_capabilities: RwLock::new(server_capabilities), typing_notice_times: Default::default(), event_handlers: Default::default(), notification_handlers: Default::default(), @@ -1444,7 +1440,8 @@ impl Client { .send(SessionChange::UnknownToken { soft_logout: *soft_logout }); } - async fn fetch_and_cache_server_capabilities( + /// Fetches server capabilities from network; no caching. + async fn fetch_server_capabilities( &self, ) -> HttpResult<(Box<[MatrixVersion]>, BTreeMap)> { let resp = self @@ -1466,26 +1463,11 @@ impl Client { versions.push(MatrixVersion::V1_0); } - let unstable_features = resp.unstable_features; - - { - // Attempt to cache the result. - let encoded = ServerCapabilities::new(&versions, unstable_features.clone()); - if let Err(err) = self - .store() - .set_kv_data( - StateStoreDataKey::ServerCapabilities, - StateStoreDataValue::ServerCapabilities(encoded), - ) - .await - { - warn!("error when caching server capabilities: {err}"); - } - } - - Ok((versions.into(), unstable_features)) + Ok((versions.into(), resp.unstable_features)) } + /// Load server capabilities from storage, or fetch them from network and + /// cache them. async fn load_or_fetch_server_capabilities( &self, ) -> HttpResult<(Box<[MatrixVersion]>, BTreeMap)> { @@ -1506,28 +1488,54 @@ impl Client { } } - self.fetch_and_cache_server_capabilities().await - } + let (versions, unstable_features) = self.fetch_server_capabilities().await?; - pub(crate) async fn server_versions(&self) -> HttpResult> { - if let Some(server_versions) = self.inner.server_versions.read().await.as_ref() { - return Ok(server_versions.clone()); + // Attempt to cache the result in storage. + { + let encoded = ServerCapabilities::new(&versions, unstable_features.clone()); + if let Err(err) = self + .store() + .set_kv_data( + StateStoreDataKey::ServerCapabilities, + StateStoreDataValue::ServerCapabilities(encoded), + ) + .await + { + warn!("error when caching server capabilities: {err}"); + } } - let mut guard = self.inner.server_versions.write().await; - if let Some(server_versions) = guard.as_ref() { - return Ok(server_versions.clone()); + Ok((versions, unstable_features)) + } + + async fn get_or_load_and_cache_server_capabilities< + T, + F: Fn(&ClientServerCapabilities) -> Option, + >( + &self, + f: F, + ) -> HttpResult { + let caps = &self.inner.server_capabilities; + if let Some(val) = f(&*caps.read().await) { + return Ok(val); + } + + let mut guard = caps.write().await; + if let Some(val) = f(&guard) { + return Ok(val); } let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?; - *guard = Some(versions.clone()); - drop(guard); + guard.server_versions = Some(versions); + guard.unstable_features = Some(unstable_features); - // Fill the unstable features, while we're at it. - *self.inner.unstable_features.write().await = Some(unstable_features); + // SAFETY: both fields were set above, so the function will always return some. + Ok(f(&guard).unwrap()) + } - Ok(versions) + pub(crate) async fn server_versions(&self) -> HttpResult> { + self.get_or_load_and_cache_server_capabilities(|caps| caps.server_versions.clone()).await } /// Get unstable features from by fetching from the server or the cache. @@ -1545,24 +1553,7 @@ impl Client { /// # anyhow::Ok(()) }; /// ``` pub async fn unstable_features(&self) -> HttpResult> { - if let Some(unstable_features) = self.inner.unstable_features.read().await.as_ref() { - return Ok(unstable_features.clone()); - } - - let mut guard = self.inner.unstable_features.write().await; - if let Some(unstable_features) = guard.as_ref() { - return Ok(unstable_features.clone()); - } - - let (versions, unstable_features) = self.load_or_fetch_server_capabilities().await?; - - *guard = Some(unstable_features.clone()); - drop(guard); - - // Fill the versions, while we're at it. - *self.inner.server_versions.write().await = Some(versions); - - Ok(unstable_features) + self.get_or_load_and_cache_server_capabilities(|caps| caps.unstable_features.clone()).await } /// Empty the server version and unstable features cache. @@ -1572,8 +1563,9 @@ impl Client { /// functions makes it possible to force reset it. pub async fn reset_server_capabilities(&self) -> Result<()> { // Empty the in-memory caches. - *self.inner.server_versions.write().await = None; - *self.inner.unstable_features.write().await = None; + let mut guard = self.inner.server_capabilities.write().await; + guard.server_versions = None; + guard.unstable_features = None; // Empty the store cache. Ok(self.store().remove_kv_data(StateStoreDataKey::ServerCapabilities).await?) @@ -2185,8 +2177,7 @@ impl Client { sliding_sync_proxy, self.inner.http_client.clone(), self.inner.base_client.clone_with_in_memory_state_store(), - self.inner.server_versions.read().await.clone(), - self.inner.unstable_features.read().await.clone(), + self.inner.server_capabilities.read().await.clone(), self.inner.respect_login_well_known, self.inner.event_cache.clone(), self.inner.send_queue_data.clone(), @@ -2254,6 +2245,15 @@ impl WeakClient { } } +#[derive(Clone)] +struct ClientServerCapabilities { + /// The Matrix versions the server supports (well-known ones only). + server_versions: Option>, + + /// The unstable features and their on/off state on the server. + unstable_features: Option>, +} + // The http mocking library is not supported for wasm32 #[cfg(all(test, not(target_arch = "wasm32")))] pub(crate) mod tests {