mirror of
https://github.com/redlib-org/redlib.git
synced 2026-06-11 12:44:16 -04:00
chore: further isolated hyper
This commit is contained in:
@@ -6,7 +6,7 @@ use arc_swap::ArcSwap;
|
||||
use cached::proc_macro::cached;
|
||||
use futures_lite::future::block_on;
|
||||
use futures_lite::{future::Boxed, FutureExt};
|
||||
use hyper::{body::Buf, header, Body, Request, Response};
|
||||
use hyper::{body::Buf, header, Body, Request as HyperRequest, Response as HyperResponse};
|
||||
use log::{error, trace, warn};
|
||||
use percent_encoding::{percent_encode, CONTROLS};
|
||||
use serde_json::Value;
|
||||
@@ -14,7 +14,7 @@ use std::result::Result;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU16};
|
||||
use std::sync::LazyLock;
|
||||
use wreq::{Client as WreqClient, Method};
|
||||
use wreq::{header as wreq_header, Client as WreqClient, Method, Response as WreqResponse};
|
||||
use wreq_util::Emulation;
|
||||
|
||||
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
|
||||
@@ -86,14 +86,14 @@ pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, S
|
||||
|
||||
let res = res.ok_or_else(|| "Unable to make HEAD request to Reddit.".to_string())?;
|
||||
let status = res.status().as_u16();
|
||||
let policy_error = res.headers().get(header::RETRY_AFTER).is_some();
|
||||
let policy_error = res.headers().get(wreq_header::RETRY_AFTER).is_some();
|
||||
|
||||
match status {
|
||||
// If Reddit responds with a 2xx, then the path is already canonical.
|
||||
200..=299 => Ok(Some(path)),
|
||||
|
||||
// If Reddit responds with a 301, then the path is redirected.
|
||||
301 => match res.headers().get(header::LOCATION) {
|
||||
301 => match res.headers().get(wreq_header::LOCATION) {
|
||||
Some(val) => {
|
||||
let Ok(original) = val.to_str() else {
|
||||
return Err("Unable to decode Location header.".to_string());
|
||||
@@ -131,13 +131,13 @@ pub async fn canonical_path(path: String, tries: i8) -> Result<Option<String>, S
|
||||
_ => Ok(
|
||||
res
|
||||
.headers()
|
||||
.get(header::LOCATION)
|
||||
.get(wreq_header::LOCATION)
|
||||
.map(|val| percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, String> {
|
||||
pub async fn proxy(req: HyperRequest<Body>, format: &str) -> Result<HyperResponse<Body>, String> {
|
||||
let mut url = format!("{format}?{}", req.uri().query().unwrap_or_default());
|
||||
|
||||
// For each parameter in request
|
||||
@@ -146,33 +146,6 @@ pub async fn proxy(req: Request<Body>, format: &str) -> Result<Response<Body>, S
|
||||
url = url.replace(&format!("{{{name}}}"), value);
|
||||
}
|
||||
|
||||
stream(&url, &req).await
|
||||
}
|
||||
|
||||
fn to_hyper_response(res: wreq::Response) -> Response<Body> {
|
||||
let status = res.status();
|
||||
let version = res.version();
|
||||
|
||||
let mut builder = Response::builder().status(status.as_u16()).version(match version {
|
||||
wreq::Version::HTTP_09 => hyper::Version::HTTP_09,
|
||||
wreq::Version::HTTP_10 => hyper::Version::HTTP_10,
|
||||
wreq::Version::HTTP_11 => hyper::Version::HTTP_11,
|
||||
wreq::Version::HTTP_2 => hyper::Version::HTTP_2,
|
||||
wreq::Version::HTTP_3 => hyper::Version::HTTP_3,
|
||||
_ => hyper::Version::HTTP_11,
|
||||
});
|
||||
|
||||
for (name, value) in res.headers() {
|
||||
builder = builder.header(
|
||||
header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(),
|
||||
header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
builder.body(Body::wrap_stream(res.bytes_stream())).unwrap()
|
||||
}
|
||||
|
||||
async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String> {
|
||||
// First parameter is target URL (mandatory).
|
||||
let wreq_uri = wreq::Uri::try_from(url).map_err(|_| "Couldn't parse URL".to_string())?;
|
||||
|
||||
@@ -212,19 +185,19 @@ async fn stream(url: &str, req: &Request<Body>) -> Result<Response<Body>, String
|
||||
rm("Nel");
|
||||
rm("Report-To");
|
||||
|
||||
to_hyper_response(res)
|
||||
res.into_hyper_response()
|
||||
})
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP
|
||||
/// 3xx codes Reddit returns and will automatically redirect.
|
||||
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<Response<Body>, String>> {
|
||||
fn reddit_get(path: String, quarantine: bool) -> Boxed<Result<WreqResponse, String>> {
|
||||
request(&Method::GET, path, true, quarantine, REDDIT_URL_BASE, REDDIT_URL_BASE_HOST)
|
||||
}
|
||||
|
||||
/// Makes a HEAD request to Reddit at `path, using the short URL base. This will not follow redirects.
|
||||
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
|
||||
fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<WreqResponse, String>> {
|
||||
request(&Method::HEAD, path, false, quarantine, base_path, host)
|
||||
}
|
||||
|
||||
@@ -237,7 +210,7 @@ fn reddit_short_head(path: String, quarantine: bool, base_path: &'static str, ho
|
||||
/// Makes a request to Reddit. If `redirect` is `true`, `request_with_redirect`
|
||||
/// will recurse on the URL that Reddit provides in the Location HTTP header
|
||||
/// in its response.
|
||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<Response<Body>, String>> {
|
||||
fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool, base_path: &'static str, host: &'static str) -> Boxed<Result<WreqResponse, String>> {
|
||||
// Build Reddit URL from path.
|
||||
let url = format!("{base_path}{path}");
|
||||
|
||||
@@ -276,7 +249,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
// redirect based on caller params.
|
||||
if response.status().is_redirection() {
|
||||
if !redirect {
|
||||
return Ok(to_hyper_response(response));
|
||||
return Ok(response);
|
||||
};
|
||||
let location_header = response.headers().get(wreq::header::LOCATION);
|
||||
if location_header.and_then(|h| h.to_str().ok()) == Some(ALTERNATIVE_REDDIT_URL_BASE) {
|
||||
@@ -314,7 +287,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
|
||||
.await;
|
||||
};
|
||||
|
||||
Ok(to_hyper_response(response))
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
|
||||
@@ -370,7 +343,7 @@ pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
|
||||
};
|
||||
|
||||
// asynchronously aggregate the chunks of the body
|
||||
match hyper::body::aggregate(response).await {
|
||||
match hyper::body::aggregate(response.into_hyper_response()).await {
|
||||
Ok(body) => {
|
||||
let has_remaining = body.has_remaining();
|
||||
|
||||
@@ -483,6 +456,35 @@ pub async fn rate_limit_check() -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
trait IntoHyperResponse {
|
||||
fn into_hyper_response(self) -> HyperResponse<Body>;
|
||||
}
|
||||
|
||||
impl IntoHyperResponse for WreqResponse {
|
||||
fn into_hyper_response(self) -> HyperResponse<Body> {
|
||||
let status = self.status();
|
||||
let version = self.version();
|
||||
|
||||
let mut builder = HyperResponse::builder().status(status.as_u16()).version(match version {
|
||||
wreq::Version::HTTP_09 => hyper::Version::HTTP_09,
|
||||
wreq::Version::HTTP_10 => hyper::Version::HTTP_10,
|
||||
wreq::Version::HTTP_11 => hyper::Version::HTTP_11,
|
||||
wreq::Version::HTTP_2 => hyper::Version::HTTP_2,
|
||||
wreq::Version::HTTP_3 => hyper::Version::HTTP_3,
|
||||
_ => hyper::Version::HTTP_11,
|
||||
});
|
||||
|
||||
for (name, value) in self.headers() {
|
||||
builder = builder.header(
|
||||
header::HeaderName::from_bytes(name.as_str().as_bytes()).unwrap(),
|
||||
header::HeaderValue::from_bytes(value.as_bytes()).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
builder.body(Body::wrap_stream(self.bytes_stream())).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user