diff --git a/src/client.rs b/src/client.rs
index b8c94b9..470c9d7 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -6,7 +6,7 @@ use arc_swap::ArcSwap;
use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
-use hyper::{body::Buf, header, Body, Request, Response};
+use hyper::{body::Buf, header, Body, Request as HyperRequest, Response as HyperResponse};
use log::{error, trace, warn};
use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
@@ -14,7 +14,7 @@ use std::result::Result;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::sync::LazyLock;
-use wreq::{Client as WreqClient, Method};
+use wreq::{header as wreq_header, Client as WreqClient, Method, Response as WreqResponse};
use wreq_util::Emulation;
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
@@ -86,14 +86,14 @@ pub async fn canonical_path(path: String, tries: i8) -> Result, S
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
let status = res.status().as_u16();
- let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
+ let policy_error = res.headers().get(wreq_header::RETRY_AFTER).is_some();
match status {
// If Reddit responds with a 2xx, then the path is already canonical.
200..=299 => Ok(Some(path)),
// If Reddit responds with a 301, then the path is redirected.
- 301 => match res.headers().get(header::LOCATION) {
+ 301 => match res.headers().get(wreq_header::LOCATION) {
Some(val) => {
let Ok(original) = val.to_str() else {
return Err("Unable to decode Location header.".to_string());
@@ -131,13 +131,13 @@ pub async fn canonical_path(path: String, tries: i8) -> Result , S
_ => Ok(
res
.headers()
- .get(header::LOCATION)
+ .get(wreq_header::LOCATION)
.map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()),
),
}
}
-pub async fn proxy(req: Request, format: &str) -> Result, String> {
+pub async fn proxy(req: HyperRequest, format: &str) -> Result, String> {
let mut url = format!("{format}?{}", req.uri().query().unwrap_or_default());
// For each parameter in request
@@ -146,33 +146,6 @@ pub async fn proxy(req: Request, format: &str) -> Result, S
url = url.replace(&format!("{{{name}}}"), value);
}
- stream(&url, &req).await
-}
-
-fn to_hyper_response(res: wreq::Response) -> Response {
- let status = res.status();
- let version = res.version();
-
- let mut builder = Response::builder().status(status.as_u16()).version(match version {
- wreq::Version::HTTP_09 => hyper::Version::HTTP_09,
- wreq::Version::HTTP_10 => hyper::Version::HTTP_10,
- wreq::Version::HTTP_11 => hyper::Version::HTTP_11,
- wreq::Version::HTTP_2 => hyper::Version::HTTP_2,
- wreq::Version::HTTP_3 => hyper::Version::HTTP_3,
- _ => hyper::Version::HTTP_11,
- });
-
- for (name, value) in res.headers() {
- builder = builder.header(
- header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(),
- header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
- );
- }
-
- builder.body(Body::wrap_stream(res.bytes_stream())).unwrap()
-}
-
-async fn stream(url: &str, req: &Request) -> Result, String> {
// First parameter is target URL (mandatory).
let wreq_uri = wreq::Uri::try_from(url).map_err(|_| "Couldn't parse URL".to_string())?;
@@ -212,19 +185,19 @@ async fn stream(url: &str, req: &Request) -> Result, String
rm("Nel");
rm("Report-To");
- to_hyper_response(res)
+ res.into_hyper_response()
})
.map_err(|e| e.to_string())
}
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
/// 3xx codes Reddit returns and will automatically redirect.
-fn reddit_get(path: String, quarantine: bool) -> Boxed, String>> {
+fn reddit_get(path: String, quarantine: bool) -> Boxed> {
request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
}
/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
-fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> {
+fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed> {
request(&Method::HEAD, path, false, quarantine, base_path, host)
}
@@ -237,7 +210,7 @@ fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, ho
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
/// will recurse on the URL that Reddit provides in the Location HTTP header
/// in its response.
-fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed, String>> {
+fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed> {
// Build Reddit URL from path.
let url = format!("{base_path}{path}");
@@ -276,7 +249,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// redirect based on caller params.
if response.status().is_redirection() {
if !redirect {
- return Ok(to_hyper_response(response));
+ return Ok(response);
};
let location_header = response.headers().get(wreq::header::LOCATION);
if location_header.and_then(|h| h.to_str().ok()) == Some(ALTERNATIVE_REDDIT_URL_BASE) {
@@ -314,7 +287,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
.await;
};
- Ok(to_hyper_response(response))
+ Ok(response)
}
Err(e) => {
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
@@ -370,7 +343,7 @@ pub async fn json(path: String, quarantine: bool) -> Result {
};
// asynchronously aggregate the chunks of the body
- match hyper::body::aggregate(response).await {
+ match hyper::body::aggregate(response.into_hyper_response()).await {
Ok(body) => {
let has_remaining = body.has_remaining();
@@ -483,6 +456,35 @@ pub async fn rate_limit_check() -> Result<(), String> {
Ok(())
}
+trait IntoHyperResponse {
+ fn into_hyper_response(self) -> HyperResponse;
+}
+
+impl IntoHyperResponse for WreqResponse {
+ fn into_hyper_response(self) -> HyperResponse {
+ let status = self.status();
+ let version = self.version();
+
+ let mut builder = HyperResponse::builder().status(status.as_u16()).version(match version {
+ wreq::Version::HTTP_09 => hyper::Version::HTTP_09,
+ wreq::Version::HTTP_10 => hyper::Version::HTTP_10,
+ wreq::Version::HTTP_11 => hyper::Version::HTTP_11,
+ wreq::Version::HTTP_2 => hyper::Version::HTTP_2,
+ wreq::Version::HTTP_3 => hyper::Version::HTTP_3,
+ _ => hyper::Version::HTTP_11,
+ });
+
+ for (name, value) in self.headers() {
+ builder = builder.header(
+ header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(),
+ header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
+ );
+ }
+
+ builder.body(Body::wrap_stream(self.bytes_stream())).unwrap()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;