From 5273fa6f423dcf2e916114db4f1b5949cbe8e408 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Mon, 19 Jun 2023 14:48:20 +0200 Subject: [PATCH] sdk: Move target-specific HTTP code into new submodules --- crates/matrix-sdk/src/http_client.rs | 512 -------------------- crates/matrix-sdk/src/http_client/mod.rs | 241 +++++++++ crates/matrix-sdk/src/http_client/native.rs | 260 ++++++++++ crates/matrix-sdk/src/http_client/wasm.rs | 47 ++ 4 files changed, 548 insertions(+), 512 deletions(-) delete mode 100644 crates/matrix-sdk/src/http_client.rs create mode 100644 crates/matrix-sdk/src/http_client/mod.rs create mode 100644 crates/matrix-sdk/src/http_client/native.rs create mode 100644 crates/matrix-sdk/src/http_client/wasm.rs diff --git a/crates/matrix-sdk/src/http_client.rs b/crates/matrix-sdk/src/http_client.rs deleted file mode 100644 index 956167fe5..000000000 --- a/crates/matrix-sdk/src/http_client.rs +++ /dev/null @@ -1,512 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{ - any::type_name, - fmt::Debug, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, - time::Duration, -}; - -use bytes::{Bytes, BytesMut}; -use bytesize::ByteSize; -use eyeball::shared::Observable as SharedObservable; -use ruma::{ - api::{ - error::{FromHttpResponseError, IntoHttpError}, - AuthScheme, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingRequestAppserviceExt, - SendAccessToken, - }, - UserId, -}; -use tracing::{debug, field::debug, instrument, trace}; - -use crate::{config::RequestConfig, error::HttpError}; - -pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30); - -#[derive(Debug)] -pub(crate) struct HttpClient { - pub(crate) inner: reqwest::Client, - pub(crate) request_config: RequestConfig, - next_request_id: Arc, -} - -impl HttpClient { - pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self { - HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() } - } - - fn get_request_id(&self) -> String { - let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); - format!("REQ-{request_id}") - } - - fn serialize_request( - &self, - request: R, - config: RequestConfig, - homeserver: String, - access_token: Option<&str>, - user_id: Option<&UserId>, - server_versions: &[MatrixVersion], - ) -> Result, IntoHttpError> - where - R: OutgoingRequest + Debug, - { - trace!(request_type = type_name::(), "Serializing request"); - - // We can't assert the identity without a user_id. - let request = if let Some((access_token, user_id)) = - access_token.filter(|_| config.assert_identity).zip(user_id) - { - request.try_into_http_request_with_user_id::( - &homeserver, - SendAccessToken::Always(access_token), - user_id, - server_versions, - )? - } else { - let send_access_token = match access_token { - Some(access_token) => { - if config.force_auth { - SendAccessToken::Always(access_token) - } else { - SendAccessToken::IfRequired(access_token) - } - } - None => SendAccessToken::None, - }; - - request.try_into_http_request::( - &homeserver, - send_access_token, - server_versions, - )? - }; - - let request = request.map(|body| body.freeze()); - - Ok(request) - } - - async fn send_request( - &self, - request: http::Request, - config: RequestConfig, - send_progress: SharedObservable, - ) -> Result - where - R: OutgoingRequest + Debug, - HttpError: From>, - { - // There's a bunch of state in send_request_with_retries, factor - // out a pinned inner future to reduce this size of futures that - // await this function. - #[cfg(not(target_arch = "wasm32"))] - let response = - Box::pin(self.send_request_with_retries::(config, request, send_progress)).await?; - - #[cfg(target_arch = "wasm32")] - let response = { - // Silence unused warning - _ = (config, send_progress); - - let request = reqwest::Request::try_from(request)?; - let response = response_to_http_response(self.inner.execute(request).await?).await?; - - let status_code = response.status(); - let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); - tracing::Span::current() - .record("status", status_code.as_u16()) - .record("response_size", response_size.to_string_as(true)); - - R::IncomingResponse::try_from_http_response(response)? - }; - - Ok(response) - } - - #[cfg(not(target_arch = "wasm32"))] - async fn send_request_with_retries( - &self, - config: RequestConfig, - request: http::Request, - send_progress: SharedObservable, - ) -> Result - where - R: OutgoingRequest + Debug, - HttpError: From>, - { - use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; - use ruma::api::client::error::{ - ErrorBody as ClientApiErrorBody, ErrorKind as ClientApiErrorKind, - }; - - use crate::RumaApiError; - - let backoff = - ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() }; - let retry_count = AtomicU64::new(1); - - let send_request = || { - let send_progress = send_progress.clone(); - async { - let stop = if let Some(retry_limit) = config.retry_limit { - retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit - } else { - false - }; - - // Turn errors into permanent errors when the retry limit is reached - let error_type = if stop { - RetryError::Permanent - } else { - |err: HttpError| { - if let Some(api_error) = err.as_ruma_api_error() { - let status_code = match api_error { - RumaApiError::ClientApi(e) => match e.body { - ClientApiErrorBody::Standard { - kind: ClientApiErrorKind::LimitExceeded { retry_after_ms }, - .. - } => { - return RetryError::Transient { - err, - retry_after: retry_after_ms, - }; - } - _ => Some(e.status_code), - }, - RumaApiError::Uiaa(_) => None, - RumaApiError::Other(e) => Some(e.status_code), - }; - - if let Some(status_code) = status_code { - if status_code.is_server_error() { - return RetryError::Transient { err, retry_after: None }; - } - } - } - - RetryError::Permanent(err) - } - }; - - let response = send_request( - &self.inner, - clone_request(&request), - config.timeout, - send_progress, - ) - .await - .map_err(error_type)?; - - let status_code = response.status(); - let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); - tracing::Span::current() - .record("status", status_code.as_u16()) - .record("response_size", response_size.to_string_as(true)); - - R::IncomingResponse::try_from_http_response(response) - .map_err(|e| error_type(HttpError::from(e))) - } - }; - - retry::<_, HttpError, _, _, _>(backoff, send_request).await - } - - #[allow(clippy::too_many_arguments)] - #[instrument( - skip(self, access_token, config, request, user_id, send_progress), - fields( - config, - path, - user_id, - request_size, - request_body, - request_id, - status, - response_size, - ) - )] - pub async fn send( - &self, - request: R, - config: Option, - homeserver: String, - access_token: Option<&str>, - user_id: Option<&UserId>, - server_versions: &[MatrixVersion], - send_progress: SharedObservable, - ) -> Result - where - R: OutgoingRequest + Debug, - HttpError: From>, - { - let config = match config { - Some(config) => config, - None => self.request_config, - }; - - // Keep some local variables in a separate scope so the compiler doesn't include - // them in the future type. https://github.com/rust-lang/rust/issues/57478 - let request = { - let request_id = self.get_request_id(); - let span = tracing::Span::current(); - - // At this point in the code, the config isn't behind an Option anymore, that's - // why we record it here, instead of in the #[instrument] macro. - span.record("config", debug(config)).record("request_id", request_id); - - // The user ID is only used if we're an app-service. Only log the user_id if - // it's `Some` and if assert_identity is set. - if config.assert_identity { - span.record("user_id", user_id.map(debug)); - } - - let auth_scheme = R::METADATA.authentication; - if !matches!(auth_scheme, AuthScheme::AccessToken | AuthScheme::None) { - return Err(HttpError::NotClientRequest); - } - - let request = self.serialize_request( - request, - config, - homeserver, - access_token, - user_id, - server_versions, - )?; - - let request_size = ByteSize(request.body().len().try_into().unwrap_or(u64::MAX)); - span.record("request_size", request_size.to_string_as(true)); - - // Since sliding sync is experimental, and the proxy might not do what we expect - // it to do given a specific request body, it's useful to log the - // request body here. This doesn't contain any personal information. - // TODO: Remove this once sliding sync isn't experimental anymore. - #[cfg(feature = "experimental-sliding-sync")] - if type_name::() == "ruma_client_api::sync::sync_events::v4::Request" { - span.record("request_body", debug(request.body())); - span.record("path", request.uri().path_and_query().map(|p| p.as_str())); - } else { - span.record("path", request.uri().path()); - } - - #[cfg(not(feature = "experimental-sliding-sync"))] - span.record("path", request.uri().path()); - - request - }; - - debug!("Sending request"); - match self.send_request::(request, config, send_progress).await { - Ok(response) => { - debug!("Got response"); - Ok(response) - } - Err(e) => { - debug!("Error while sending request: {e:?}"); - Err(e) - } - } - } -} - -#[cfg(not(target_arch = "wasm32"))] -#[derive(Clone, Debug)] -pub(crate) struct HttpSettings { - pub(crate) disable_ssl_verification: bool, - pub(crate) proxy: Option, - pub(crate) user_agent: Option, - pub(crate) timeout: Duration, -} - -#[cfg(not(target_arch = "wasm32"))] -impl Default for HttpSettings { - fn default() -> Self { - Self { - disable_ssl_verification: false, - proxy: None, - user_agent: None, - timeout: DEFAULT_REQUEST_TIMEOUT, - } - } -} - -#[cfg(not(target_arch = "wasm32"))] -impl HttpSettings { - /// Build a client with the specified configuration. - pub(crate) fn make_client(&self) -> Result { - let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned()); - let mut http_client = - reqwest::Client::builder().user_agent(user_agent).timeout(self.timeout); - - if self.disable_ssl_verification { - http_client = http_client.danger_accept_invalid_certs(true) - } - - if let Some(p) = &self.proxy { - http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?); - } - - Ok(http_client.build()?) - } -} - -/// Progress of sending or receiving a payload. -#[derive(Clone, Copy, Debug, Default)] -pub struct TransmissionProgress { - /// How many bytes were already transferred. - pub current: usize, - /// How many bytes there are in total. - pub total: usize, -} - -// Clones all request parts except the extensions which can't be cloned. -// See also https://github.com/hyperium/http/issues/395 -#[cfg(not(target_arch = "wasm32"))] -fn clone_request(request: &http::Request) -> http::Request { - let mut builder = http::Request::builder() - .version(request.version()) - .method(request.method()) - .uri(request.uri()); - *builder.headers_mut().unwrap() = request.headers().clone(); - builder.body(request.body().clone()).unwrap() -} - -async fn response_to_http_response( - mut response: reqwest::Response, -) -> Result, reqwest::Error> { - let status = response.status(); - - let mut http_builder = http::Response::builder().status(status); - let headers = http_builder.headers_mut().expect("Can't get the response builder headers"); - - for (k, v) in response.headers_mut().drain() { - if let Some(key) = k { - headers.insert(key, v); - } - } - - let body = response.bytes().await?; - - Ok(http_builder.body(body).expect("Can't construct a response using the given body")) -} - -#[cfg(not(target_arch = "wasm32"))] -async fn send_request( - client: &reqwest::Client, - request: http::Request, - timeout: Duration, - send_progress: SharedObservable, -) -> Result, HttpError> { - use std::convert::Infallible; - - use futures_util::stream; - - let request = { - let mut request = if send_progress.subscriber_count() != 0 { - send_progress.update(|p| p.total += request.body().len()); - reqwest::Request::try_from(request.map(|body| { - let chunks = stream::iter(BytesChunks::new(body, 8192).map( - move |chunk| -> Result<_, Infallible> { - send_progress.update(|p| p.current += chunk.len()); - Ok(chunk) - }, - )); - reqwest::Body::wrap_stream(chunks) - }))? - } else { - reqwest::Request::try_from(request)? - }; - - *request.timeout_mut() = Some(timeout); - request - }; - - let response = client.execute(request).await?; - Ok(response_to_http_response(response).await?) -} - -#[cfg(not(target_arch = "wasm32"))] -struct BytesChunks { - bytes: Bytes, - size: usize, -} - -#[cfg(not(target_arch = "wasm32"))] -impl BytesChunks { - fn new(bytes: Bytes, size: usize) -> Self { - assert_ne!(size, 0); - Self { bytes, size } - } -} - -#[cfg(not(target_arch = "wasm32"))] -impl Iterator for BytesChunks { - type Item = Bytes; - - fn next(&mut self) -> Option { - use std::mem; - - if self.bytes.is_empty() { - None - } else if self.bytes.len() < self.size { - Some(mem::take(&mut self.bytes)) - } else { - Some(self.bytes.split_to(self.size)) - } - } -} - -#[cfg(all(test, not(target_arch = "wasm32")))] -mod tests { - use bytes::Bytes; - - use super::BytesChunks; - - #[test] - fn bytes_chunks() { - let bytes = Bytes::new(); - assert!(BytesChunks::new(bytes, 1).collect::>().is_empty()); - - let bytes = Bytes::from_iter([1, 2]); - assert_eq!(BytesChunks::new(bytes, 2).collect::>(), [Bytes::from_iter([1, 2])]); - - let bytes = Bytes::from_iter([1, 2]); - assert_eq!(BytesChunks::new(bytes, 3).collect::>(), [Bytes::from_iter([1, 2])]); - - let bytes = Bytes::from_iter([1, 2, 3]); - assert_eq!( - BytesChunks::new(bytes, 1).collect::>(), - [Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])] - ); - - let bytes = Bytes::from_iter([1, 2, 3]); - assert_eq!( - BytesChunks::new(bytes, 2).collect::>(), - [Bytes::from_iter([1, 2]), Bytes::from_iter([3])] - ); - - let bytes = Bytes::from_iter([1, 2, 3, 4]); - assert_eq!( - BytesChunks::new(bytes, 2).collect::>(), - [Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])] - ); - } -} diff --git a/crates/matrix-sdk/src/http_client/mod.rs b/crates/matrix-sdk/src/http_client/mod.rs new file mode 100644 index 000000000..f6ea0caac --- /dev/null +++ b/crates/matrix-sdk/src/http_client/mod.rs @@ -0,0 +1,241 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::type_name, + fmt::Debug, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +use bytes::{Bytes, BytesMut}; +use bytesize::ByteSize; +use eyeball::shared::Observable as SharedObservable; +use ruma::{ + api::{ + error::{FromHttpResponseError, IntoHttpError}, + AuthScheme, MatrixVersion, OutgoingRequest, OutgoingRequestAppserviceExt, SendAccessToken, + }, + UserId, +}; +use tracing::{debug, field::debug, instrument, trace}; + +use crate::{config::RequestConfig, error::HttpError}; + +#[cfg(not(target_arch = "wasm32"))] +mod native; +#[cfg(target_arch = "wasm32")] +mod wasm; + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) use native::HttpSettings; + +pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30); + +#[derive(Debug)] +pub(crate) struct HttpClient { + pub(crate) inner: reqwest::Client, + pub(crate) request_config: RequestConfig, + next_request_id: Arc, +} + +impl HttpClient { + pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self { + HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() } + } + + fn get_request_id(&self) -> String { + let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst); + format!("REQ-{request_id}") + } + + fn serialize_request( + &self, + request: R, + config: RequestConfig, + homeserver: String, + access_token: Option<&str>, + user_id: Option<&UserId>, + server_versions: &[MatrixVersion], + ) -> Result, IntoHttpError> + where + R: OutgoingRequest + Debug, + { + trace!(request_type = type_name::(), "Serializing request"); + + // We can't assert the identity without a user_id. + let request = if let Some((access_token, user_id)) = + access_token.filter(|_| config.assert_identity).zip(user_id) + { + request.try_into_http_request_with_user_id::( + &homeserver, + SendAccessToken::Always(access_token), + user_id, + server_versions, + )? + } else { + let send_access_token = match access_token { + Some(access_token) => { + if config.force_auth { + SendAccessToken::Always(access_token) + } else { + SendAccessToken::IfRequired(access_token) + } + } + None => SendAccessToken::None, + }; + + request.try_into_http_request::( + &homeserver, + send_access_token, + server_versions, + )? + }; + + let request = request.map(|body| body.freeze()); + + Ok(request) + } + + #[allow(clippy::too_many_arguments)] + #[instrument( + skip(self, access_token, config, request, user_id, send_progress), + fields( + config, + path, + user_id, + request_size, + request_body, + request_id, + status, + response_size, + ) + )] + pub async fn send( + &self, + request: R, + config: Option, + homeserver: String, + access_token: Option<&str>, + user_id: Option<&UserId>, + server_versions: &[MatrixVersion], + send_progress: SharedObservable, + ) -> Result + where + R: OutgoingRequest + Debug, + HttpError: From>, + { + let config = match config { + Some(config) => config, + None => self.request_config, + }; + + // Keep some local variables in a separate scope so the compiler doesn't include + // them in the future type. https://github.com/rust-lang/rust/issues/57478 + let request = { + let request_id = self.get_request_id(); + let span = tracing::Span::current(); + + // At this point in the code, the config isn't behind an Option anymore, that's + // why we record it here, instead of in the #[instrument] macro. + span.record("config", debug(config)).record("request_id", request_id); + + // The user ID is only used if we're an app-service. Only log the user_id if + // it's `Some` and if assert_identity is set. + if config.assert_identity { + span.record("user_id", user_id.map(debug)); + } + + let auth_scheme = R::METADATA.authentication; + if !matches!(auth_scheme, AuthScheme::AccessToken | AuthScheme::None) { + return Err(HttpError::NotClientRequest); + } + + let request = self.serialize_request( + request, + config, + homeserver, + access_token, + user_id, + server_versions, + )?; + + let request_size = ByteSize(request.body().len().try_into().unwrap_or(u64::MAX)); + span.record("request_size", request_size.to_string_as(true)); + + // Since sliding sync is experimental, and the proxy might not do what we expect + // it to do given a specific request body, it's useful to log the + // request body here. This doesn't contain any personal information. + // TODO: Remove this once sliding sync isn't experimental anymore. + #[cfg(feature = "experimental-sliding-sync")] + if type_name::() == "ruma_client_api::sync::sync_events::v4::Request" { + span.record("request_body", debug(request.body())); + span.record("path", request.uri().path_and_query().map(|p| p.as_str())); + } else { + span.record("path", request.uri().path()); + } + + #[cfg(not(feature = "experimental-sliding-sync"))] + span.record("path", request.uri().path()); + + request + }; + + debug!("Sending request"); + + // There's a bunch of state in send_request, factor out a pinned inner + // future to reduce this size of futures that await this function. + match Box::pin(self.send_request::(request, config, send_progress)).await { + Ok(response) => { + debug!("Got response"); + Ok(response) + } + Err(e) => { + debug!("Error while sending request: {e:?}"); + Err(e) + } + } + } +} + +/// Progress of sending or receiving a payload. +#[derive(Clone, Copy, Debug, Default)] +pub struct TransmissionProgress { + /// How many bytes were already transferred. + pub current: usize, + /// How many bytes there are in total. + pub total: usize, +} + +async fn response_to_http_response( + mut response: reqwest::Response, +) -> Result, reqwest::Error> { + let status = response.status(); + + let mut http_builder = http::Response::builder().status(status); + let headers = http_builder.headers_mut().expect("Can't get the response builder headers"); + + for (k, v) in response.headers_mut().drain() { + if let Some(key) = k { + headers.insert(key, v); + } + } + + let body = response.bytes().await?; + + Ok(http_builder.body(body).expect("Can't construct a response using the given body")) +} diff --git a/crates/matrix-sdk/src/http_client/native.rs b/crates/matrix-sdk/src/http_client/native.rs new file mode 100644 index 000000000..15893bee8 --- /dev/null +++ b/crates/matrix-sdk/src/http_client/native.rs @@ -0,0 +1,260 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + fmt::Debug, + mem, + sync::atomic::{AtomicU64, Ordering}, + time::Duration, +}; + +use backoff::{future::retry, Error as RetryError, ExponentialBackoff}; +use bytes::Bytes; +use bytesize::ByteSize; +use eyeball::shared::Observable as SharedObservable; +use ruma::api::{ + client::error::{ErrorBody as ClientApiErrorBody, ErrorKind as ClientApiErrorKind}, + error::FromHttpResponseError, + IncomingResponse, OutgoingRequest, +}; + +use super::{response_to_http_response, HttpClient, TransmissionProgress, DEFAULT_REQUEST_TIMEOUT}; +use crate::{config::RequestConfig, error::HttpError, RumaApiError}; + +impl HttpClient { + pub(super) async fn send_request( + &self, + request: http::Request, + config: RequestConfig, + send_progress: SharedObservable, + ) -> Result + where + R: OutgoingRequest + Debug, + HttpError: From>, + { + let backoff = + ExponentialBackoff { max_elapsed_time: config.retry_timeout, ..Default::default() }; + let retry_count = AtomicU64::new(1); + + let send_request = || { + let send_progress = send_progress.clone(); + async { + let stop = if let Some(retry_limit) = config.retry_limit { + retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit + } else { + false + }; + + // Turn errors into permanent errors when the retry limit is reached + let error_type = if stop { + RetryError::Permanent + } else { + |err: HttpError| { + if let Some(api_error) = err.as_ruma_api_error() { + let status_code = match api_error { + RumaApiError::ClientApi(e) => match e.body { + ClientApiErrorBody::Standard { + kind: ClientApiErrorKind::LimitExceeded { retry_after_ms }, + .. + } => { + return RetryError::Transient { + err, + retry_after: retry_after_ms, + }; + } + _ => Some(e.status_code), + }, + RumaApiError::Uiaa(_) => None, + RumaApiError::Other(e) => Some(e.status_code), + }; + + if let Some(status_code) = status_code { + if status_code.is_server_error() { + return RetryError::Transient { err, retry_after: None }; + } + } + } + + RetryError::Permanent(err) + } + }; + + let response = send_request(&self.inner, &request, config.timeout, send_progress) + .await + .map_err(error_type)?; + + let status_code = response.status(); + let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); + tracing::Span::current() + .record("status", status_code.as_u16()) + .record("response_size", response_size.to_string_as(true)); + + R::IncomingResponse::try_from_http_response(response) + .map_err(|e| error_type(HttpError::from(e))) + } + }; + + retry::<_, HttpError, _, _, _>(backoff, send_request).await + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[derive(Clone, Debug)] +pub(crate) struct HttpSettings { + pub(crate) disable_ssl_verification: bool, + pub(crate) proxy: Option, + pub(crate) user_agent: Option, + pub(crate) timeout: Duration, +} + +#[cfg(not(target_arch = "wasm32"))] +impl Default for HttpSettings { + fn default() -> Self { + Self { + disable_ssl_verification: false, + proxy: None, + user_agent: None, + timeout: DEFAULT_REQUEST_TIMEOUT, + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl HttpSettings { + /// Build a client with the specified configuration. + pub(crate) fn make_client(&self) -> Result { + let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned()); + let mut http_client = + reqwest::Client::builder().user_agent(user_agent).timeout(self.timeout); + + if self.disable_ssl_verification { + http_client = http_client.danger_accept_invalid_certs(true) + } + + if let Some(p) = &self.proxy { + http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?); + } + + Ok(http_client.build()?) + } +} + +pub(super) async fn send_request( + client: &reqwest::Client, + request: &http::Request, + timeout: Duration, + send_progress: SharedObservable, +) -> Result, HttpError> { + use std::convert::Infallible; + + use futures_util::stream; + + let request = clone_request(request); + let request = { + let mut request = if send_progress.subscriber_count() != 0 { + send_progress.update(|p| p.total += request.body().len()); + reqwest::Request::try_from(request.map(|body| { + let chunks = stream::iter(BytesChunks::new(body, 8192).map( + move |chunk| -> Result<_, Infallible> { + send_progress.update(|p| p.current += chunk.len()); + Ok(chunk) + }, + )); + reqwest::Body::wrap_stream(chunks) + }))? + } else { + reqwest::Request::try_from(request)? + }; + + *request.timeout_mut() = Some(timeout); + request + }; + + let response = client.execute(request).await?; + Ok(response_to_http_response(response).await?) +} + +// Clones all request parts except the extensions which can't be cloned. +// See also https://github.com/hyperium/http/issues/395 +fn clone_request(request: &http::Request) -> http::Request { + let mut builder = http::Request::builder() + .version(request.version()) + .method(request.method()) + .uri(request.uri()); + *builder.headers_mut().unwrap() = request.headers().clone(); + builder.body(request.body().clone()).unwrap() +} + +struct BytesChunks { + bytes: Bytes, + size: usize, +} + +impl BytesChunks { + fn new(bytes: Bytes, size: usize) -> Self { + assert_ne!(size, 0); + Self { bytes, size } + } +} + +impl Iterator for BytesChunks { + type Item = Bytes; + + fn next(&mut self) -> Option { + if self.bytes.is_empty() { + None + } else if self.bytes.len() < self.size { + Some(mem::take(&mut self.bytes)) + } else { + Some(self.bytes.split_to(self.size)) + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::BytesChunks; + + #[test] + fn bytes_chunks() { + let bytes = Bytes::new(); + assert!(BytesChunks::new(bytes, 1).collect::>().is_empty()); + + let bytes = Bytes::from_iter([1, 2]); + assert_eq!(BytesChunks::new(bytes, 2).collect::>(), [Bytes::from_iter([1, 2])]); + + let bytes = Bytes::from_iter([1, 2]); + assert_eq!(BytesChunks::new(bytes, 3).collect::>(), [Bytes::from_iter([1, 2])]); + + let bytes = Bytes::from_iter([1, 2, 3]); + assert_eq!( + BytesChunks::new(bytes, 1).collect::>(), + [Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])] + ); + + let bytes = Bytes::from_iter([1, 2, 3]); + assert_eq!( + BytesChunks::new(bytes, 2).collect::>(), + [Bytes::from_iter([1, 2]), Bytes::from_iter([3])] + ); + + let bytes = Bytes::from_iter([1, 2, 3, 4]); + assert_eq!( + BytesChunks::new(bytes, 2).collect::>(), + [Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])] + ); + } +} diff --git a/crates/matrix-sdk/src/http_client/wasm.rs b/crates/matrix-sdk/src/http_client/wasm.rs new file mode 100644 index 000000000..0071c2eaf --- /dev/null +++ b/crates/matrix-sdk/src/http_client/wasm.rs @@ -0,0 +1,47 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Debug; + +use bytes::Bytes; +use bytesize::ByteSize; +use eyeball::shared::Observable as SharedObservable; +use ruma::api::{error::FromHttpResponseError, IncomingResponse, OutgoingRequest}; + +use super::{response_to_http_response, HttpClient, TransmissionProgress}; +use crate::{config::RequestConfig, error::HttpError}; + +impl HttpClient { + pub(super) async fn send_request( + &self, + request: http::Request, + _config: RequestConfig, + _send_progress: SharedObservable, + ) -> Result + where + R: OutgoingRequest + Debug, + HttpError: From>, + { + let request = reqwest::Request::try_from(request)?; + let response = response_to_http_response(self.inner.execute(request).await?).await?; + + let status_code = response.status(); + let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX)); + tracing::Span::current() + .record("status", status_code.as_u16()) + .record("response_size", response_size.to_string_as(true)); + + Ok(R::IncomingResponse::try_from_http_response(response)?) + } +}