mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-07 15:33:45 -04:00
Allow to limit the number of concurrent requests made by the sdk (#3625)
Add a new `max_concurrent_requests` parameter in the `RequestConfig` limits the number of http(s) requests the internal sdk client issues concurrently (if > 0). The default behavior is the same as before: there is no limit on concurrent requests issued. This is especially useful for resource constrained platforms (e.g. mobile platforms), and if your pattern might lead to issuing many requests at the same time (like downloading and caching all avatars at startup). - [x] Public API changes documented in changelogs (optional) Signed-off-by: Benjamin Kampmann <ben.kampmann@gmail.com>
This commit is contained in:
committed by
GitHub
parent
aaccfdfea5
commit
d49cb54b67
@@ -22,6 +22,7 @@ Breaking changes:
|
||||
|
||||
Additions:
|
||||
|
||||
- new `RequestConfig.max_concurrent_requests` which allows to limit the maximum number of concurrent requests the internal HTTP client issues (all others have to wait until the number drops below that threshold again)
|
||||
- Expose new method `Client::Oidc::login_with_qr_code()`.
|
||||
([#3466](https://github.com/matrix-org/matrix-rust-sdk/pull/3466))
|
||||
- Add the `ClientBuilder::add_root_certificates()` method which re-exposes the
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::{
|
||||
fmt::{self, Debug},
|
||||
num::NonZeroUsize,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
@@ -44,18 +45,21 @@ pub struct RequestConfig {
|
||||
pub(crate) timeout: Duration,
|
||||
pub(crate) retry_limit: Option<u64>,
|
||||
pub(crate) retry_timeout: Option<Duration>,
|
||||
pub(crate) max_concurrent_requests: Option<NonZeroUsize>,
|
||||
pub(crate) force_auth: bool,
|
||||
}
|
||||
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
impl Debug for RequestConfig {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let Self { timeout, retry_limit, retry_timeout, force_auth } = self;
|
||||
let Self { timeout, retry_limit, retry_timeout, force_auth, max_concurrent_requests } =
|
||||
self;
|
||||
|
||||
let mut res = fmt.debug_struct("RequestConfig");
|
||||
res.field("timeout", timeout)
|
||||
.maybe_field("retry_limit", retry_limit)
|
||||
.maybe_field("retry_timeout", retry_timeout);
|
||||
.maybe_field("retry_timeout", retry_timeout)
|
||||
.maybe_field("max_concurrent_requests", max_concurrent_requests);
|
||||
|
||||
if *force_auth {
|
||||
res.field("force_auth", &true);
|
||||
@@ -71,6 +75,7 @@ impl Default for RequestConfig {
|
||||
timeout: DEFAULT_REQUEST_TIMEOUT,
|
||||
retry_limit: Default::default(),
|
||||
retry_timeout: Default::default(),
|
||||
max_concurrent_requests: Default::default(),
|
||||
force_auth: false,
|
||||
}
|
||||
}
|
||||
@@ -106,6 +111,15 @@ impl RequestConfig {
|
||||
self
|
||||
}
|
||||
|
||||
/// The total limit of request that are pending or run concurrently.
|
||||
/// Any additional request beyond that number will be waiting until another
|
||||
/// concurrent requests finished. Requests are queued fairly.
|
||||
#[must_use]
|
||||
pub fn max_concurrent_requests(mut self, limit: Option<NonZeroUsize>) -> Self {
|
||||
self.max_concurrent_requests = limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the timeout duration for all HTTP requests.
|
||||
#[must_use]
|
||||
pub fn timeout(mut self, timeout: Duration) -> Self {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use std::{
|
||||
any::type_name,
|
||||
fmt::Debug,
|
||||
num::NonZeroUsize,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
@@ -30,6 +31,7 @@ use ruma::api::{
|
||||
error::{FromHttpResponseError, IntoHttpError},
|
||||
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
};
|
||||
use tokio::sync::{Semaphore, SemaphorePermit};
|
||||
use tracing::{debug, field::debug, instrument, trace};
|
||||
|
||||
use crate::{config::RequestConfig, error::HttpError};
|
||||
@@ -48,16 +50,48 @@ pub(crate) use native::HttpSettings;
|
||||
|
||||
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct MaybeSemaphore(Arc<Option<Semaphore>>);
|
||||
|
||||
#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
|
||||
struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
|
||||
|
||||
impl MaybeSemaphore {
|
||||
fn new(max: Option<NonZeroUsize>) -> Self {
|
||||
let inner = max.map(|i| Semaphore::new(i.into()));
|
||||
MaybeSemaphore(Arc::new(inner))
|
||||
}
|
||||
|
||||
async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
|
||||
match self.0.as_ref() {
|
||||
Some(inner) => {
|
||||
// This can only ever error if the semaphore was closed,
|
||||
// which we never do, so we can safely ignore any error case
|
||||
MaybeSemaphorePermit(inner.acquire().await.ok())
|
||||
}
|
||||
None => MaybeSemaphorePermit(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct HttpClient {
|
||||
pub(crate) inner: reqwest::Client,
|
||||
pub(crate) request_config: RequestConfig,
|
||||
concurrent_request_semaphore: MaybeSemaphore,
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
|
||||
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
|
||||
HttpClient {
|
||||
inner,
|
||||
request_config,
|
||||
concurrent_request_semaphore: MaybeSemaphore::new(
|
||||
request_config.max_concurrent_requests,
|
||||
),
|
||||
next_request_id: AtomicU64::new(0).into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_request_id(&self) -> String {
|
||||
@@ -184,6 +218,9 @@ impl HttpClient {
|
||||
request
|
||||
};
|
||||
|
||||
// will be automatically dropped at the end of this function
|
||||
let _handle = self.concurrent_request_semaphore.acquire().await;
|
||||
|
||||
debug!("Sending request");
|
||||
|
||||
// There's a bunch of state in send_request, factor out a pinned inner
|
||||
@@ -259,3 +296,111 @@ impl tower::Service<http_old::Request<Bytes>> for HttpClient {
|
||||
Box::pin(fut)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(target_arch = "wasm32")))]
|
||||
mod tests {
|
||||
use std::{
|
||||
num::NonZeroUsize,
|
||||
sync::{
|
||||
atomic::{AtomicU8, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use matrix_sdk_test::{async_test, test_json};
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
Mock, Request, ResponseTemplate,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
http_client::RequestConfig,
|
||||
test_utils::{set_client_session, test_client_builder_with_server},
|
||||
};
|
||||
|
||||
#[async_test]
|
||||
async fn ensure_concurrent_request_limit_is_observed() {
|
||||
let (client_builder, server) = test_client_builder_with_server().await;
|
||||
let client = client_builder
|
||||
.request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
set_client_session(&client).await;
|
||||
|
||||
let counter = Arc::new(AtomicU8::new(0));
|
||||
let inner_counter = counter.clone();
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/_matrix/client/versions"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("_matrix/client/r0/account/whoami"))
|
||||
.respond_with(move |_req: &Request| {
|
||||
inner_counter.fetch_add(1, Ordering::SeqCst);
|
||||
// we stall the requests
|
||||
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
|
||||
})
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let bg_task = tokio::spawn(async move {
|
||||
futures_util::future::join_all((0..10).map(|_| client.whoami())).await
|
||||
});
|
||||
|
||||
// give it some time to issue the requests
|
||||
tokio::time::sleep(Duration::from_millis(300)).await;
|
||||
|
||||
assert_eq!(
|
||||
counter.load(Ordering::SeqCst),
|
||||
5,
|
||||
"More requests passed than the limit we configured"
|
||||
);
|
||||
bg_task.abort();
|
||||
}
|
||||
|
||||
#[async_test]
|
||||
async fn ensure_no_max_concurrent_request_does_not_limit() {
|
||||
let (client_builder, server) = test_client_builder_with_server().await;
|
||||
let client = client_builder
|
||||
.request_config(RequestConfig::default().max_concurrent_requests(None))
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
set_client_session(&client).await;
|
||||
|
||||
let counter = Arc::new(AtomicU8::new(0));
|
||||
let inner_counter = counter.clone();
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/_matrix/client/versions"))
|
||||
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("_matrix/client/r0/account/whoami"))
|
||||
.respond_with(move |_req: &Request| {
|
||||
inner_counter.fetch_add(1, Ordering::SeqCst);
|
||||
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
|
||||
})
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let bg_task = tokio::spawn(async move {
|
||||
futures_util::future::join_all((0..254).map(|_| client.whoami())).await
|
||||
});
|
||||
|
||||
// give it some time to issue the requests
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
|
||||
bg_task.abort();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user