mirror of
https://github.com/matrix-org/matrix-rust-sdk.git
synced 2026-05-18 21:52:30 -04:00
sdk: Move target-specific HTTP code into new submodules
This commit is contained in:
committed by
Jonas Platte
parent
4c1a351ead
commit
5273fa6f42
@@ -1,512 +0,0 @@
|
||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
any::type_name,
|
||||
fmt::Debug,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use bytesize::ByteSize;
|
||||
use eyeball::shared::Observable as SharedObservable;
|
||||
use ruma::{
|
||||
api::{
|
||||
error::{FromHttpResponseError, IntoHttpError},
|
||||
AuthScheme, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingRequestAppserviceExt,
|
||||
SendAccessToken,
|
||||
},
|
||||
UserId,
|
||||
};
|
||||
use tracing::{debug, field::debug, instrument, trace};
|
||||
|
||||
use crate::{config::RequestConfig, error::HttpError};
|
||||
|
||||
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct HttpClient {
|
||||
pub(crate) inner: reqwest::Client,
|
||||
pub(crate) request_config: RequestConfig,
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
|
||||
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
|
||||
}
|
||||
|
||||
fn get_request_id(&self) -> String {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
|
||||
format!("REQ-{request_id}")
|
||||
}
|
||||
|
||||
fn serialize_request<R>(
|
||||
&self,
|
||||
request: R,
|
||||
config: RequestConfig,
|
||||
homeserver: String,
|
||||
access_token: Option<&str>,
|
||||
user_id: Option<&UserId>,
|
||||
server_versions: &[MatrixVersion],
|
||||
) -> Result<http::Request<Bytes>, IntoHttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
{
|
||||
trace!(request_type = type_name::<R>(), "Serializing request");
|
||||
|
||||
// We can't assert the identity without a user_id.
|
||||
let request = if let Some((access_token, user_id)) =
|
||||
access_token.filter(|_| config.assert_identity).zip(user_id)
|
||||
{
|
||||
request.try_into_http_request_with_user_id::<BytesMut>(
|
||||
&homeserver,
|
||||
SendAccessToken::Always(access_token),
|
||||
user_id,
|
||||
server_versions,
|
||||
)?
|
||||
} else {
|
||||
let send_access_token = match access_token {
|
||||
Some(access_token) => {
|
||||
if config.force_auth {
|
||||
SendAccessToken::Always(access_token)
|
||||
} else {
|
||||
SendAccessToken::IfRequired(access_token)
|
||||
}
|
||||
}
|
||||
None => SendAccessToken::None,
|
||||
};
|
||||
|
||||
request.try_into_http_request::<BytesMut>(
|
||||
&homeserver,
|
||||
send_access_token,
|
||||
server_versions,
|
||||
)?
|
||||
};
|
||||
|
||||
let request = request.map(|body| body.freeze());
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
async fn send_request<R>(
|
||||
&self,
|
||||
request: http::Request<Bytes>,
|
||||
config: RequestConfig,
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<R::IncomingResponse, HttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<R::EndpointError>>,
|
||||
{
|
||||
// There's a bunch of state in send_request_with_retries, factor
|
||||
// out a pinned inner future to reduce this size of futures that
|
||||
// await this function.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
let response =
|
||||
Box::pin(self.send_request_with_retries::<R>(config, request, send_progress)).await?;
|
||||
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
let response = {
|
||||
// Silence unused warning
|
||||
_ = (config, send_progress);
|
||||
|
||||
let request = reqwest::Request::try_from(request)?;
|
||||
let response = response_to_http_response(self.inner.execute(request).await?).await?;
|
||||
|
||||
let status_code = response.status();
|
||||
let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
|
||||
tracing::Span::current()
|
||||
.record("status", status_code.as_u16())
|
||||
.record("response_size", response_size.to_string_as(true));
|
||||
|
||||
R::IncomingResponse::try_from_http_response(response)?
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
async fn send_request_with_retries<R>(
|
||||
&self,
|
||||
config: RequestConfig,
|
||||
request: http::Request<Bytes>,
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<R::IncomingResponse, HttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<R::EndpointError>>,
|
||||
{
|
||||
use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
|
||||
use ruma::api::client::error::{
|
||||
ErrorBody as ClientApiErrorBody, ErrorKind as ClientApiErrorKind,
|
||||
};
|
||||
|
||||
use crate::RumaApiError;
|
||||
|
||||
let backoff =
|
||||
ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() };
|
||||
let retry_count = AtomicU64::new(1);
|
||||
|
||||
let send_request = || {
|
||||
let send_progress = send_progress.clone();
|
||||
async {
|
||||
let stop = if let Some(retry_limit) = config.retry_limit {
|
||||
retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Turn errors into permanent errors when the retry limit is reached
|
||||
let error_type = if stop {
|
||||
RetryError::Permanent
|
||||
} else {
|
||||
|err: HttpError| {
|
||||
if let Some(api_error) = err.as_ruma_api_error() {
|
||||
let status_code = match api_error {
|
||||
RumaApiError::ClientApi(e) => match e.body {
|
||||
ClientApiErrorBody::Standard {
|
||||
kind: ClientApiErrorKind::LimitExceeded { retry_after_ms },
|
||||
..
|
||||
} => {
|
||||
return RetryError::Transient {
|
||||
err,
|
||||
retry_after: retry_after_ms,
|
||||
};
|
||||
}
|
||||
_ => Some(e.status_code),
|
||||
},
|
||||
RumaApiError::Uiaa(_) => None,
|
||||
RumaApiError::Other(e) => Some(e.status_code),
|
||||
};
|
||||
|
||||
if let Some(status_code) = status_code {
|
||||
if status_code.is_server_error() {
|
||||
return RetryError::Transient { err, retry_after: None };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RetryError::Permanent(err)
|
||||
}
|
||||
};
|
||||
|
||||
let response = send_request(
|
||||
&self.inner,
|
||||
clone_request(&request),
|
||||
config.timeout,
|
||||
send_progress,
|
||||
)
|
||||
.await
|
||||
.map_err(error_type)?;
|
||||
|
||||
let status_code = response.status();
|
||||
let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
|
||||
tracing::Span::current()
|
||||
.record("status", status_code.as_u16())
|
||||
.record("response_size", response_size.to_string_as(true));
|
||||
|
||||
R::IncomingResponse::try_from_http_response(response)
|
||||
.map_err(|e| error_type(HttpError::from(e)))
|
||||
}
|
||||
};
|
||||
|
||||
retry::<_, HttpError, _, _, _>(backoff, send_request).await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(
|
||||
skip(self, access_token, config, request, user_id, send_progress),
|
||||
fields(
|
||||
config,
|
||||
path,
|
||||
user_id,
|
||||
request_size,
|
||||
request_body,
|
||||
request_id,
|
||||
status,
|
||||
response_size,
|
||||
)
|
||||
)]
|
||||
pub async fn send<R>(
|
||||
&self,
|
||||
request: R,
|
||||
config: Option<RequestConfig>,
|
||||
homeserver: String,
|
||||
access_token: Option<&str>,
|
||||
user_id: Option<&UserId>,
|
||||
server_versions: &[MatrixVersion],
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<R::IncomingResponse, HttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<R::EndpointError>>,
|
||||
{
|
||||
let config = match config {
|
||||
Some(config) => config,
|
||||
None => self.request_config,
|
||||
};
|
||||
|
||||
// Keep some local variables in a separate scope so the compiler doesn't include
|
||||
// them in the future type. https://github.com/rust-lang/rust/issues/57478
|
||||
let request = {
|
||||
let request_id = self.get_request_id();
|
||||
let span = tracing::Span::current();
|
||||
|
||||
// At this point in the code, the config isn't behind an Option anymore, that's
|
||||
// why we record it here, instead of in the #[instrument] macro.
|
||||
span.record("config", debug(config)).record("request_id", request_id);
|
||||
|
||||
// The user ID is only used if we're an app-service. Only log the user_id if
|
||||
// it's `Some` and if assert_identity is set.
|
||||
if config.assert_identity {
|
||||
span.record("user_id", user_id.map(debug));
|
||||
}
|
||||
|
||||
let auth_scheme = R::METADATA.authentication;
|
||||
if !matches!(auth_scheme, AuthScheme::AccessToken | AuthScheme::None) {
|
||||
return Err(HttpError::NotClientRequest);
|
||||
}
|
||||
|
||||
let request = self.serialize_request(
|
||||
request,
|
||||
config,
|
||||
homeserver,
|
||||
access_token,
|
||||
user_id,
|
||||
server_versions,
|
||||
)?;
|
||||
|
||||
let request_size = ByteSize(request.body().len().try_into().unwrap_or(u64::MAX));
|
||||
span.record("request_size", request_size.to_string_as(true));
|
||||
|
||||
// Since sliding sync is experimental, and the proxy might not do what we expect
|
||||
// it to do given a specific request body, it's useful to log the
|
||||
// request body here. This doesn't contain any personal information.
|
||||
// TODO: Remove this once sliding sync isn't experimental anymore.
|
||||
#[cfg(feature = "experimental-sliding-sync")]
|
||||
if type_name::<R>() == "ruma_client_api::sync::sync_events::v4::Request" {
|
||||
span.record("request_body", debug(request.body()));
|
||||
span.record("path", request.uri().path_and_query().map(|p| p.as_str()));
|
||||
} else {
|
||||
span.record("path", request.uri().path());
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "experimental-sliding-sync"))]
|
||||
span.record("path", request.uri().path());
|
||||
|
||||
request
|
||||
};
|
||||
|
||||
debug!("Sending request");
|
||||
match self.send_request::<R>(request, config, send_progress).await {
|
||||
Ok(response) => {
|
||||
debug!("Got response");
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Error while sending request: {e:?}");
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct HttpSettings {
|
||||
pub(crate) disable_ssl_verification: bool,
|
||||
pub(crate) proxy: Option<String>,
|
||||
pub(crate) user_agent: Option<String>,
|
||||
pub(crate) timeout: Duration,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl Default for HttpSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
disable_ssl_verification: false,
|
||||
proxy: None,
|
||||
user_agent: None,
|
||||
timeout: DEFAULT_REQUEST_TIMEOUT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl HttpSettings {
|
||||
/// Build a client with the specified configuration.
|
||||
pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
|
||||
let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
|
||||
let mut http_client =
|
||||
reqwest::Client::builder().user_agent(user_agent).timeout(self.timeout);
|
||||
|
||||
if self.disable_ssl_verification {
|
||||
http_client = http_client.danger_accept_invalid_certs(true)
|
||||
}
|
||||
|
||||
if let Some(p) = &self.proxy {
|
||||
http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
|
||||
}
|
||||
|
||||
Ok(http_client.build()?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Progress of sending or receiving a payload.
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct TransmissionProgress {
|
||||
/// How many bytes were already transferred.
|
||||
pub current: usize,
|
||||
/// How many bytes there are in total.
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
// Clones all request parts except the extensions which can't be cloned.
|
||||
// See also https://github.com/hyperium/http/issues/395
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn clone_request(request: &http::Request<Bytes>) -> http::Request<Bytes> {
|
||||
let mut builder = http::Request::builder()
|
||||
.version(request.version())
|
||||
.method(request.method())
|
||||
.uri(request.uri());
|
||||
*builder.headers_mut().unwrap() = request.headers().clone();
|
||||
builder.body(request.body().clone()).unwrap()
|
||||
}
|
||||
|
||||
async fn response_to_http_response(
|
||||
mut response: reqwest::Response,
|
||||
) -> Result<http::Response<Bytes>, reqwest::Error> {
|
||||
let status = response.status();
|
||||
|
||||
let mut http_builder = http::Response::builder().status(status);
|
||||
let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
|
||||
|
||||
for (k, v) in response.headers_mut().drain() {
|
||||
if let Some(key) = k {
|
||||
headers.insert(key, v);
|
||||
}
|
||||
}
|
||||
|
||||
let body = response.bytes().await?;
|
||||
|
||||
Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
async fn send_request(
|
||||
client: &reqwest::Client,
|
||||
request: http::Request<Bytes>,
|
||||
timeout: Duration,
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<http::Response<Bytes>, HttpError> {
|
||||
use std::convert::Infallible;
|
||||
|
||||
use futures_util::stream;
|
||||
|
||||
let request = {
|
||||
let mut request = if send_progress.subscriber_count() != 0 {
|
||||
send_progress.update(|p| p.total += request.body().len());
|
||||
reqwest::Request::try_from(request.map(|body| {
|
||||
let chunks = stream::iter(BytesChunks::new(body, 8192).map(
|
||||
move |chunk| -> Result<_, Infallible> {
|
||||
send_progress.update(|p| p.current += chunk.len());
|
||||
Ok(chunk)
|
||||
},
|
||||
));
|
||||
reqwest::Body::wrap_stream(chunks)
|
||||
}))?
|
||||
} else {
|
||||
reqwest::Request::try_from(request)?
|
||||
};
|
||||
|
||||
*request.timeout_mut() = Some(timeout);
|
||||
request
|
||||
};
|
||||
|
||||
let response = client.execute(request).await?;
|
||||
Ok(response_to_http_response(response).await?)
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
struct BytesChunks {
|
||||
bytes: Bytes,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl BytesChunks {
|
||||
fn new(bytes: Bytes, size: usize) -> Self {
|
||||
assert_ne!(size, 0);
|
||||
Self { bytes, size }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl Iterator for BytesChunks {
|
||||
type Item = Bytes;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
use std::mem;
|
||||
|
||||
if self.bytes.is_empty() {
|
||||
None
|
||||
} else if self.bytes.len() < self.size {
|
||||
Some(mem::take(&mut self.bytes))
|
||||
} else {
|
||||
Some(self.bytes.split_to(self.size))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(target_arch = "wasm32")))]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
|
||||
use super::BytesChunks;
|
||||
|
||||
#[test]
|
||||
fn bytes_chunks() {
|
||||
let bytes = Bytes::new();
|
||||
assert!(BytesChunks::new(bytes, 1).collect::<Vec<_>>().is_empty());
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2]);
|
||||
assert_eq!(BytesChunks::new(bytes, 2).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2]);
|
||||
assert_eq!(BytesChunks::new(bytes, 3).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2, 3]);
|
||||
assert_eq!(
|
||||
BytesChunks::new(bytes, 1).collect::<Vec<_>>(),
|
||||
[Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])]
|
||||
);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2, 3]);
|
||||
assert_eq!(
|
||||
BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
|
||||
[Bytes::from_iter([1, 2]), Bytes::from_iter([3])]
|
||||
);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2, 3, 4]);
|
||||
assert_eq!(
|
||||
BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
|
||||
[Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])]
|
||||
);
|
||||
}
|
||||
}
|
||||
241
crates/matrix-sdk/src/http_client/mod.rs
Normal file
241
crates/matrix-sdk/src/http_client/mod.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
any::type_name,
|
||||
fmt::Debug,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use bytesize::ByteSize;
|
||||
use eyeball::shared::Observable as SharedObservable;
|
||||
use ruma::{
|
||||
api::{
|
||||
error::{FromHttpResponseError, IntoHttpError},
|
||||
AuthScheme, MatrixVersion, OutgoingRequest, OutgoingRequestAppserviceExt, SendAccessToken,
|
||||
},
|
||||
UserId,
|
||||
};
|
||||
use tracing::{debug, field::debug, instrument, trace};
|
||||
|
||||
use crate::{config::RequestConfig, error::HttpError};
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod native;
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
mod wasm;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
pub(crate) use native::HttpSettings;
|
||||
|
||||
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct HttpClient {
|
||||
pub(crate) inner: reqwest::Client,
|
||||
pub(crate) request_config: RequestConfig,
|
||||
next_request_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
|
||||
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
|
||||
}
|
||||
|
||||
fn get_request_id(&self) -> String {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
|
||||
format!("REQ-{request_id}")
|
||||
}
|
||||
|
||||
fn serialize_request<R>(
|
||||
&self,
|
||||
request: R,
|
||||
config: RequestConfig,
|
||||
homeserver: String,
|
||||
access_token: Option<&str>,
|
||||
user_id: Option<&UserId>,
|
||||
server_versions: &[MatrixVersion],
|
||||
) -> Result<http::Request<Bytes>, IntoHttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
{
|
||||
trace!(request_type = type_name::<R>(), "Serializing request");
|
||||
|
||||
// We can't assert the identity without a user_id.
|
||||
let request = if let Some((access_token, user_id)) =
|
||||
access_token.filter(|_| config.assert_identity).zip(user_id)
|
||||
{
|
||||
request.try_into_http_request_with_user_id::<BytesMut>(
|
||||
&homeserver,
|
||||
SendAccessToken::Always(access_token),
|
||||
user_id,
|
||||
server_versions,
|
||||
)?
|
||||
} else {
|
||||
let send_access_token = match access_token {
|
||||
Some(access_token) => {
|
||||
if config.force_auth {
|
||||
SendAccessToken::Always(access_token)
|
||||
} else {
|
||||
SendAccessToken::IfRequired(access_token)
|
||||
}
|
||||
}
|
||||
None => SendAccessToken::None,
|
||||
};
|
||||
|
||||
request.try_into_http_request::<BytesMut>(
|
||||
&homeserver,
|
||||
send_access_token,
|
||||
server_versions,
|
||||
)?
|
||||
};
|
||||
|
||||
let request = request.map(|body| body.freeze());
|
||||
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(
|
||||
skip(self, access_token, config, request, user_id, send_progress),
|
||||
fields(
|
||||
config,
|
||||
path,
|
||||
user_id,
|
||||
request_size,
|
||||
request_body,
|
||||
request_id,
|
||||
status,
|
||||
response_size,
|
||||
)
|
||||
)]
|
||||
pub async fn send<R>(
|
||||
&self,
|
||||
request: R,
|
||||
config: Option<RequestConfig>,
|
||||
homeserver: String,
|
||||
access_token: Option<&str>,
|
||||
user_id: Option<&UserId>,
|
||||
server_versions: &[MatrixVersion],
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<R::IncomingResponse, HttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<R::EndpointError>>,
|
||||
{
|
||||
let config = match config {
|
||||
Some(config) => config,
|
||||
None => self.request_config,
|
||||
};
|
||||
|
||||
// Keep some local variables in a separate scope so the compiler doesn't include
|
||||
// them in the future type. https://github.com/rust-lang/rust/issues/57478
|
||||
let request = {
|
||||
let request_id = self.get_request_id();
|
||||
let span = tracing::Span::current();
|
||||
|
||||
// At this point in the code, the config isn't behind an Option anymore, that's
|
||||
// why we record it here, instead of in the #[instrument] macro.
|
||||
span.record("config", debug(config)).record("request_id", request_id);
|
||||
|
||||
// The user ID is only used if we're an app-service. Only log the user_id if
|
||||
// it's `Some` and if assert_identity is set.
|
||||
if config.assert_identity {
|
||||
span.record("user_id", user_id.map(debug));
|
||||
}
|
||||
|
||||
let auth_scheme = R::METADATA.authentication;
|
||||
if !matches!(auth_scheme, AuthScheme::AccessToken | AuthScheme::None) {
|
||||
return Err(HttpError::NotClientRequest);
|
||||
}
|
||||
|
||||
let request = self.serialize_request(
|
||||
request,
|
||||
config,
|
||||
homeserver,
|
||||
access_token,
|
||||
user_id,
|
||||
server_versions,
|
||||
)?;
|
||||
|
||||
let request_size = ByteSize(request.body().len().try_into().unwrap_or(u64::MAX));
|
||||
span.record("request_size", request_size.to_string_as(true));
|
||||
|
||||
// Since sliding sync is experimental, and the proxy might not do what we expect
|
||||
// it to do given a specific request body, it's useful to log the
|
||||
// request body here. This doesn't contain any personal information.
|
||||
// TODO: Remove this once sliding sync isn't experimental anymore.
|
||||
#[cfg(feature = "experimental-sliding-sync")]
|
||||
if type_name::<R>() == "ruma_client_api::sync::sync_events::v4::Request" {
|
||||
span.record("request_body", debug(request.body()));
|
||||
span.record("path", request.uri().path_and_query().map(|p| p.as_str()));
|
||||
} else {
|
||||
span.record("path", request.uri().path());
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "experimental-sliding-sync"))]
|
||||
span.record("path", request.uri().path());
|
||||
|
||||
request
|
||||
};
|
||||
|
||||
debug!("Sending request");
|
||||
|
||||
// There's a bunch of state in send_request, factor out a pinned inner
|
||||
// future to reduce this size of futures that await this function.
|
||||
match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
|
||||
Ok(response) => {
|
||||
debug!("Got response");
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Error while sending request: {e:?}");
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Progress of sending or receiving a payload.
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
pub struct TransmissionProgress {
|
||||
/// How many bytes were already transferred.
|
||||
pub current: usize,
|
||||
/// How many bytes there are in total.
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
async fn response_to_http_response(
|
||||
mut response: reqwest::Response,
|
||||
) -> Result<http::Response<Bytes>, reqwest::Error> {
|
||||
let status = response.status();
|
||||
|
||||
let mut http_builder = http::Response::builder().status(status);
|
||||
let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
|
||||
|
||||
for (k, v) in response.headers_mut().drain() {
|
||||
if let Some(key) = k {
|
||||
headers.insert(key, v);
|
||||
}
|
||||
}
|
||||
|
||||
let body = response.bytes().await?;
|
||||
|
||||
Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
|
||||
}
|
||||
260
crates/matrix-sdk/src/http_client/native.rs
Normal file
260
crates/matrix-sdk/src/http_client/native.rs
Normal file
@@ -0,0 +1,260 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
mem,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
|
||||
use bytes::Bytes;
|
||||
use bytesize::ByteSize;
|
||||
use eyeball::shared::Observable as SharedObservable;
|
||||
use ruma::api::{
|
||||
client::error::{ErrorBody as ClientApiErrorBody, ErrorKind as ClientApiErrorKind},
|
||||
error::FromHttpResponseError,
|
||||
IncomingResponse, OutgoingRequest,
|
||||
};
|
||||
|
||||
use super::{response_to_http_response, HttpClient, TransmissionProgress, DEFAULT_REQUEST_TIMEOUT};
|
||||
use crate::{config::RequestConfig, error::HttpError, RumaApiError};
|
||||
|
||||
impl HttpClient {
|
||||
pub(super) async fn send_request<R>(
|
||||
&self,
|
||||
request: http::Request<Bytes>,
|
||||
config: RequestConfig,
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<R::IncomingResponse, HttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<R::EndpointError>>,
|
||||
{
|
||||
let backoff =
|
||||
ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() };
|
||||
let retry_count = AtomicU64::new(1);
|
||||
|
||||
let send_request = || {
|
||||
let send_progress = send_progress.clone();
|
||||
async {
|
||||
let stop = if let Some(retry_limit) = config.retry_limit {
|
||||
retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// Turn errors into permanent errors when the retry limit is reached
|
||||
let error_type = if stop {
|
||||
RetryError::Permanent
|
||||
} else {
|
||||
|err: HttpError| {
|
||||
if let Some(api_error) = err.as_ruma_api_error() {
|
||||
let status_code = match api_error {
|
||||
RumaApiError::ClientApi(e) => match e.body {
|
||||
ClientApiErrorBody::Standard {
|
||||
kind: ClientApiErrorKind::LimitExceeded { retry_after_ms },
|
||||
..
|
||||
} => {
|
||||
return RetryError::Transient {
|
||||
err,
|
||||
retry_after: retry_after_ms,
|
||||
};
|
||||
}
|
||||
_ => Some(e.status_code),
|
||||
},
|
||||
RumaApiError::Uiaa(_) => None,
|
||||
RumaApiError::Other(e) => Some(e.status_code),
|
||||
};
|
||||
|
||||
if let Some(status_code) = status_code {
|
||||
if status_code.is_server_error() {
|
||||
return RetryError::Transient { err, retry_after: None };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RetryError::Permanent(err)
|
||||
}
|
||||
};
|
||||
|
||||
let response = send_request(&self.inner, &request, config.timeout, send_progress)
|
||||
.await
|
||||
.map_err(error_type)?;
|
||||
|
||||
let status_code = response.status();
|
||||
let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
|
||||
tracing::Span::current()
|
||||
.record("status", status_code.as_u16())
|
||||
.record("response_size", response_size.to_string_as(true));
|
||||
|
||||
R::IncomingResponse::try_from_http_response(response)
|
||||
.map_err(|e| error_type(HttpError::from(e)))
|
||||
}
|
||||
};
|
||||
|
||||
retry::<_, HttpError, _, _, _>(backoff, send_request).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct HttpSettings {
|
||||
pub(crate) disable_ssl_verification: bool,
|
||||
pub(crate) proxy: Option<String>,
|
||||
pub(crate) user_agent: Option<String>,
|
||||
pub(crate) timeout: Duration,
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl Default for HttpSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
disable_ssl_verification: false,
|
||||
proxy: None,
|
||||
user_agent: None,
|
||||
timeout: DEFAULT_REQUEST_TIMEOUT,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
impl HttpSettings {
|
||||
/// Build a client with the specified configuration.
|
||||
pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
|
||||
let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
|
||||
let mut http_client =
|
||||
reqwest::Client::builder().user_agent(user_agent).timeout(self.timeout);
|
||||
|
||||
if self.disable_ssl_verification {
|
||||
http_client = http_client.danger_accept_invalid_certs(true)
|
||||
}
|
||||
|
||||
if let Some(p) = &self.proxy {
|
||||
http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
|
||||
}
|
||||
|
||||
Ok(http_client.build()?)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn send_request(
|
||||
client: &reqwest::Client,
|
||||
request: &http::Request<Bytes>,
|
||||
timeout: Duration,
|
||||
send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<http::Response<Bytes>, HttpError> {
|
||||
use std::convert::Infallible;
|
||||
|
||||
use futures_util::stream;
|
||||
|
||||
let request = clone_request(request);
|
||||
let request = {
|
||||
let mut request = if send_progress.subscriber_count() != 0 {
|
||||
send_progress.update(|p| p.total += request.body().len());
|
||||
reqwest::Request::try_from(request.map(|body| {
|
||||
let chunks = stream::iter(BytesChunks::new(body, 8192).map(
|
||||
move |chunk| -> Result<_, Infallible> {
|
||||
send_progress.update(|p| p.current += chunk.len());
|
||||
Ok(chunk)
|
||||
},
|
||||
));
|
||||
reqwest::Body::wrap_stream(chunks)
|
||||
}))?
|
||||
} else {
|
||||
reqwest::Request::try_from(request)?
|
||||
};
|
||||
|
||||
*request.timeout_mut() = Some(timeout);
|
||||
request
|
||||
};
|
||||
|
||||
let response = client.execute(request).await?;
|
||||
Ok(response_to_http_response(response).await?)
|
||||
}
|
||||
|
||||
// Clones all request parts except the extensions which can't be cloned.
|
||||
// See also https://github.com/hyperium/http/issues/395
|
||||
fn clone_request(request: &http::Request<Bytes>) -> http::Request<Bytes> {
|
||||
let mut builder = http::Request::builder()
|
||||
.version(request.version())
|
||||
.method(request.method())
|
||||
.uri(request.uri());
|
||||
*builder.headers_mut().unwrap() = request.headers().clone();
|
||||
builder.body(request.body().clone()).unwrap()
|
||||
}
|
||||
|
||||
struct BytesChunks {
|
||||
bytes: Bytes,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
impl BytesChunks {
|
||||
fn new(bytes: Bytes, size: usize) -> Self {
|
||||
assert_ne!(size, 0);
|
||||
Self { bytes, size }
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for BytesChunks {
|
||||
type Item = Bytes;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.bytes.is_empty() {
|
||||
None
|
||||
} else if self.bytes.len() < self.size {
|
||||
Some(mem::take(&mut self.bytes))
|
||||
} else {
|
||||
Some(self.bytes.split_to(self.size))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
|
||||
use super::BytesChunks;
|
||||
|
||||
#[test]
|
||||
fn bytes_chunks() {
|
||||
let bytes = Bytes::new();
|
||||
assert!(BytesChunks::new(bytes, 1).collect::<Vec<_>>().is_empty());
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2]);
|
||||
assert_eq!(BytesChunks::new(bytes, 2).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2]);
|
||||
assert_eq!(BytesChunks::new(bytes, 3).collect::<Vec<_>>(), [Bytes::from_iter([1, 2])]);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2, 3]);
|
||||
assert_eq!(
|
||||
BytesChunks::new(bytes, 1).collect::<Vec<_>>(),
|
||||
[Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])]
|
||||
);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2, 3]);
|
||||
assert_eq!(
|
||||
BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
|
||||
[Bytes::from_iter([1, 2]), Bytes::from_iter([3])]
|
||||
);
|
||||
|
||||
let bytes = Bytes::from_iter([1, 2, 3, 4]);
|
||||
assert_eq!(
|
||||
BytesChunks::new(bytes, 2).collect::<Vec<_>>(),
|
||||
[Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])]
|
||||
);
|
||||
}
|
||||
}
|
||||
47
crates/matrix-sdk/src/http_client/wasm.rs
Normal file
47
crates/matrix-sdk/src/http_client/wasm.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytesize::ByteSize;
|
||||
use eyeball::shared::Observable as SharedObservable;
|
||||
use ruma::api::{error::FromHttpResponseError, IncomingResponse, OutgoingRequest};
|
||||
|
||||
use super::{response_to_http_response, HttpClient, TransmissionProgress};
|
||||
use crate::{config::RequestConfig, error::HttpError};
|
||||
|
||||
impl HttpClient {
|
||||
pub(super) async fn send_request<R>(
|
||||
&self,
|
||||
request: http::Request<Bytes>,
|
||||
_config: RequestConfig,
|
||||
_send_progress: SharedObservable<TransmissionProgress>,
|
||||
) -> Result<R::IncomingResponse, HttpError>
|
||||
where
|
||||
R: OutgoingRequest + Debug,
|
||||
HttpError: From<FromHttpResponseError<R::EndpointError>>,
|
||||
{
|
||||
let request = reqwest::Request::try_from(request)?;
|
||||
let response = response_to_http_response(self.inner.execute(request).await?).await?;
|
||||
|
||||
let status_code = response.status();
|
||||
let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
|
||||
tracing::Span::current()
|
||||
.record("status", status_code.as_u16())
|
||||
.record("response_size", response_size.to_string_as(true));
|
||||
|
||||
Ok(R::IncomingResponse::try_from_http_response(response)?)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user