chore: further isolated hyper

This commit is contained in:
Mark Lopez
2026-04-04 13:22:10 -05:00
parent c005e6f2d3
commit d2d32428d9

View File

@@ -6,7 +6,7 @@ use arc_swap::ArcSwap;
use cached::proc_macro::cached;
use futures_lite::future::block_on;
use futures_lite::{future::Boxed, FutureExt};
use hyper::{body::Buf, header, Body, Request, Response};
use hyper::{body::Buf, header, Body, Request as HyperRequest, Response as HyperResponse};
use log::{error, trace, warn};
use percent_encoding::{percent_encode, CONTROLS};
use serde_json::Value;
@@ -14,7 +14,7 @@ use std::result::Result;
use std::sync::atomic::Ordering;
use std::sync::atomic::{AtomicBool, AtomicU16};
use std::sync::LazyLock;
use wreq::{Client as WreqClient, Method};
use wreq::{header as wreq_header, Client as WreqClient, Method, Response as WreqResponse};
use wreq_util::Emulation;
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
@@ -86,14 +86,14 @@ pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, S
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
let status = res.status().as_u16();
let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
let policy_error = res.headers().get(wreq_header::RETRY_AFTER).is_some();
match status {
// If Reddit responds with a 2xx, then the path is already canonical.
200..=299 => Ok(Some(path)),
// If Reddit responds with a 301, then the path is redirected.
301 => match res.headers().get(header::LOCATION) {
301 => match res.headers().get(wreq_header::LOCATION) {
Some(val) => {
let Ok(original) = val.to_str() else {
return Err("Unable to decode Location header.".to_string());
@@ -131,13 +131,13 @@ pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, S
_ => Ok(
res
.headers()
.get(header::LOCATION)
.get(wreq_header::LOCATION)
.map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()),
),
}
}
pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
pub async fn proxy(req: HyperRequest<Body>, format: &str) -> Result<HyperResponse<Body>, String> {
let mut url = format!("{format}?{}", req.uri().query().unwrap_or_default());
// For each parameter in request
@@ -146,33 +146,6 @@ pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, S
url = url.replace(&format!("{{{name}}}"), value);
}
stream(&url, &req).await
}
fn to_hyper_response(res: wreq::Response) -> Response<Body> {
let status = res.status();
let version = res.version();
let mut builder = Response::builder().status(status.as_u16()).version(match version {
wreq::Version::HTTP_09 => hyper::Version::HTTP_09,
wreq::Version::HTTP_10 => hyper::Version::HTTP_10,
wreq::Version::HTTP_11 => hyper::Version::HTTP_11,
wreq::Version::HTTP_2 => hyper::Version::HTTP_2,
wreq::Version::HTTP_3 => hyper::Version::HTTP_3,
_ => hyper::Version::HTTP_11,
});
for (name, value) in res.headers() {
builder = builder.header(
header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(),
header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
);
}
builder.body(Body::wrap_stream(res.bytes_stream())).unwrap()
}
async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String> {
// First parameter is target URL (mandatory).
let wreq_uri = wreq::Uri::try_from(url).map_err(|_| "Couldn't parse URL".to_string())?;
@@ -212,19 +185,19 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
rm("Nel");
rm("Report-To");
to_hyper_response(res)
res.into_hyper_response()
})
.map_err(|e| e.to_string())
}
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
/// 3xx codes Reddit returns and will automatically redirect.
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<WreqResponse, String>> {
request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
}
/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<WreqResponse, String>> {
request(&Method::HEAD, path, false, quarantine, base_path, host)
}
@@ -237,7 +210,7 @@ fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, ho
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
/// will recurse on the URL that Reddit provides in the Location HTTP header
/// in its response.
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<WreqResponse, String>> {
// Build Reddit URL from path.
let url = format!("{base_path}{path}");
@@ -276,7 +249,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
// redirect based on caller params.
if response.status().is_redirection() {
if !redirect {
return Ok(to_hyper_response(response));
return Ok(response);
};
let location_header = response.headers().get(wreq::header::LOCATION);
if location_header.and_then(|h| h.to_str().ok()) == Some(ALTERNATIVE_REDDIT_URL_BASE) {
@@ -314,7 +287,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
.await;
};
Ok(to_hyper_response(response))
Ok(response)
}
Err(e) => {
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
@@ -370,7 +343,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
};
// asynchronously aggregate the chunks of the body
match hyper::body::aggregate(response).await {
match hyper::body::aggregate(response.into_hyper_response()).await {
Ok(body) => {
let has_remaining = body.has_remaining();
@@ -483,6 +456,35 @@ pub async fn rate_limit_check() -> Result<(), String> {
Ok(())
}
trait IntoHyperResponse {
fn into_hyper_response(self) -> HyperResponse<Body>;
}
impl IntoHyperResponse for WreqResponse {
fn into_hyper_response(self) -> HyperResponse<Body> {
let status = self.status();
let version = self.version();
let mut builder = HyperResponse::builder().status(status.as_u16()).version(match version {
wreq::Version::HTTP_09 => hyper::Version::HTTP_09,
wreq::Version::HTTP_10 => hyper::Version::HTTP_10,
wreq::Version::HTTP_11 => hyper::Version::HTTP_11,
wreq::Version::HTTP_2 => hyper::Version::HTTP_2,
wreq::Version::HTTP_3 => hyper::Version::HTTP_3,
_ => hyper::Version::HTTP_11,
});
for (name, value) in self.headers() {
builder = builder.header(
header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(),
header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
);
}
builder.body(Body::wrap_stream(self.bytes_stream())).unwrap()
}
}
#[cfg(test)]
mod tests {
use super::*;