diff --git a/src/client.rs b/src/client.rs index b8c94b9..470c9d7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,7 +6,7 @@ use arc_swap::ArcSwap; use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; -use hyper::{body::Buf, header, Body, Request, Response}; +use hyper::{body::Buf, header, Body, Request as HyperRequest, Response as HyperResponse}; use log::{error, trace, warn}; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; @@ -14,7 +14,7 @@ use std::result::Result; use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, AtomicU16}; use std::sync::LazyLock; -use wreq::{Client as WreqClient, Method}; +use wreq::{header as wreq_header, Client as WreqClient, Method, Response as WreqResponse}; use wreq_util::Emulation; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; @@ -86,14 +86,14 @@ pub async fn canonical_path(path: String, tries: i8) -> Result, S let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?; let status = res.status().as_u16(); - let policy_error = res.headers().get(header::RETRY_AFTER).is_some(); + let policy_error = res.headers().get(wreq_header::RETRY_AFTER).is_some(); match status { // If Reddit responds with a 2xx, then the path is already canonical. 200..=299 => Ok(Some(path)), // If Reddit responds with a 301, then the path is redirected. - 301 => match res.headers().get(header::LOCATION) { + 301 => match res.headers().get(wreq_header::LOCATION) { Some(val) => { let Ok(original) = val.to_str() else { return Err("Unable to decode Location header.".to_string()); @@ -131,13 +131,13 @@ pub async fn canonical_path(path: String, tries: i8) -> Result, S _ => Ok( res .headers() - .get(header::LOCATION) + .get(wreq_header::LOCATION) .map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()), ), } } -pub async fn proxy(req: Request, format: &str) -> Result, String> { +pub async fn proxy(req: HyperRequest, format: &str) -> Result, String> { let mut url = format!("{format}?{}", req.uri().query().unwrap_or_default()); // For each parameter in request @@ -146,33 +146,6 @@ pub async fn proxy(req: Request, format: &str) -> Result, S url = url.replace(&format!("{{{name}}}"), value); } - stream(&url, &req).await -} - -fn to_hyper_response(res: wreq::Response) -> Response { - let status = res.status(); - let version = res.version(); - - let mut builder = Response::builder().status(status.as_u16()).version(match version { - wreq::Version::HTTP_09 => hyper::Version::HTTP_09, - wreq::Version::HTTP_10 => hyper::Version::HTTP_10, - wreq::Version::HTTP_11 => hyper::Version::HTTP_11, - wreq::Version::HTTP_2 => hyper::Version::HTTP_2, - wreq::Version::HTTP_3 => hyper::Version::HTTP_3, - _ => hyper::Version::HTTP_11, - }); - - for (name, value) in res.headers() { - builder = builder.header( - header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(), - header::HeaderValue::from_bytes(value.as_bytes()).unwrap(), - ); - } - - builder.body(Body::wrap_stream(res.bytes_stream())).unwrap() -} - -async fn stream(url: &str, req: &Request) -> Result, String> { // First parameter is target URL (mandatory). let wreq_uri = wreq::Uri::try_from(url).map_err(|_| "Couldn't parse URL".to_string())?; @@ -212,19 +185,19 @@ async fn stream(url: &str, req: &Request) -> Result, String rm("Nel"); rm("Report-To"); - to_hyper_response(res) + res.into_hyper_response() }) .map_err(|e| e.to_string()) } /// Makes a GET request to Reddit at `path`. By default, this will honor HTTP /// 3xx codes Reddit returns and will automatically redirect. -fn reddit_get(path: String, quarantine: bool) -> Boxed, String>> { +fn reddit_get(path: String, quarantine: bool) -> Boxed> { request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST) } /// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects. -fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> { +fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed> { request(&Method::HEAD, path, false, quarantine, base_path, host) } @@ -237,7 +210,7 @@ fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, ho /// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect` /// will recurse on the URL that Reddit provides in the Location HTTP header /// in its response. -fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> { +fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed> { // Build Reddit URL from path. let url = format!("{base_path}{path}"); @@ -276,7 +249,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo // redirect based on caller params. if response.status().is_redirection() { if !redirect { - return Ok(to_hyper_response(response)); + return Ok(response); }; let location_header = response.headers().get(wreq::header::LOCATION); if location_header.and_then(|h| h.to_str().ok()) == Some(ALTERNATIVE_REDDIT_URL_BASE) { @@ -314,7 +287,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo .await; }; - Ok(to_hyper_response(response)) + Ok(response) } Err(e) => { dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e); @@ -370,7 +343,7 @@ pub async fn json(path: String, quarantine: bool) -> Result { }; // asynchronously aggregate the chunks of the body - match hyper::body::aggregate(response).await { + match hyper::body::aggregate(response.into_hyper_response()).await { Ok(body) => { let has_remaining = body.has_remaining(); @@ -483,6 +456,35 @@ pub async fn rate_limit_check() -> Result<(), String> { Ok(()) } +trait IntoHyperResponse { + fn into_hyper_response(self) -> HyperResponse; +} + +impl IntoHyperResponse for WreqResponse { + fn into_hyper_response(self) -> HyperResponse { + let status = self.status(); + let version = self.version(); + + let mut builder = HyperResponse::builder().status(status.as_u16()).version(match version { + wreq::Version::HTTP_09 => hyper::Version::HTTP_09, + wreq::Version::HTTP_10 => hyper::Version::HTTP_10, + wreq::Version::HTTP_11 => hyper::Version::HTTP_11, + wreq::Version::HTTP_2 => hyper::Version::HTTP_2, + wreq::Version::HTTP_3 => hyper::Version::HTTP_3, + _ => hyper::Version::HTTP_11, + }); + + for (name, value) in self.headers() { + builder = builder.header( + header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(), + header::HeaderValue::from_bytes(value.as_bytes()).unwrap(), + ); + } + + builder.body(Body::wrap_stream(self.bytes_stream())).unwrap() + } +} + #[cfg(test)] mod tests { use super::*;