From 46933059f6019bb39d48e71d72fa46dd17a81eec Mon Sep 17 00:00:00 2001 From: Gregory Schier Date: Sat, 20 Dec 2025 14:10:55 -0800 Subject: [PATCH] Split up HTTP sending logic (#320) --- .gitignore | 1 + src-tauri/Cargo.lock | 44 +- src-tauri/Cargo.toml | 1 + src-tauri/src/error.rs | 2 +- src-tauri/src/http_request.rs | 883 +++++----------- src-tauri/src/render.rs | 2 +- src-tauri/yaak-common/Cargo.toml | 1 + src-tauri/yaak-common/src/lib.rs | 1 + src-tauri/yaak-common/src/serde.rs | 23 + src-tauri/yaak-http/Cargo.toml | 15 +- src-tauri/yaak-http/src/chained_reader.rs | 78 ++ src-tauri/yaak-http/src/client.rs | 25 +- src-tauri/yaak-http/src/decompress.rs | 188 ++++ src-tauri/yaak-http/src/error.rs | 15 + src-tauri/yaak-http/src/lib.rs | 6 + src-tauri/yaak-http/src/path_placeholders.rs | 6 +- src-tauri/yaak-http/src/proto.rs | 29 + src-tauri/yaak-http/src/sender.rs | 409 ++++++++ src-tauri/yaak-http/src/transaction.rs | 385 +++++++ src-tauri/yaak-http/src/types.rs | 975 ++++++++++++++++++ src-tauri/yaak-http/tests/test.txt | 1 + src-tauri/yaak-models/bindings/gen_models.ts | 2 +- ...251219074602_default-workspace-headers.sql | 15 + ...0251220000000_response-request-headers.sql | 3 + src-tauri/yaak-models/src/error.rs | 2 +- src-tauri/yaak-models/src/models.rs | 11 + src-tauri/yaak-plugins/bindings/gen_models.ts | 2 +- src-tauri/yaak-templates/src/renderer.rs | 6 + src-tauri/yaak-ws/src/commands.rs | 2 +- src-web/components/ExportDataDialog.tsx | 2 +- src-web/components/FolderLayout.tsx | 5 +- src-web/components/HeadersEditor.tsx | 4 +- src-web/components/HttpResponsePane.tsx | 163 +-- src-web/components/ResponseHeaders.tsx | 61 +- .../Settings/SettingsCertificates.tsx | 2 +- src-web/components/core/Button.tsx | 3 +- src-web/components/core/CountBadge.tsx | 13 +- src-web/components/core/DetailsBanner.tsx | 37 +- src-web/components/core/Icon.tsx | 4 + src-web/components/core/SizeTag.tsx | 11 +- src-web/lib/data/encodings.ts | 2 +- 41 files changed, 2708 insertions(+), 732 deletions(-) create mode 100644 src-tauri/yaak-common/src/serde.rs create mode 100644 src-tauri/yaak-http/src/chained_reader.rs create mode 100644 src-tauri/yaak-http/src/decompress.rs create mode 100644 src-tauri/yaak-http/src/proto.rs create mode 100644 src-tauri/yaak-http/src/sender.rs create mode 100644 src-tauri/yaak-http/src/transaction.rs create mode 100644 src-tauri/yaak-http/src/types.rs create mode 100644 src-tauri/yaak-http/tests/test.txt create mode 100644 src-tauri/yaak-models/migrations/20251219074602_default-workspace-headers.sql create mode 100644 src-tauri/yaak-models/migrations/20251220000000_response-request-headers.sql diff --git a/.gitignore b/.gitignore index ee2c7428..a0a25619 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ out .tmp tmp +.zed diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 30ab8329..a91f4c0a 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -192,12 +192,14 @@ version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b37fc50485c4f3f736a4fb14199f6d5f5ba008d7f28fe710306c92780f004c07" dependencies = [ - "brotli", + "brotli 8.0.1", "flate2", "futures-core", "memchr", "pin-project-lite", "tokio", + "zstd", + "zstd-safe", ] [[package]] @@ -536,6 +538,17 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "brotli" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +dependencies = [ + 
"alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor 4.0.3", +] + [[package]] name = "brotli" version = "8.0.1" @@ -544,7 +557,17 @@ checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", - "brotli-decompressor", + "brotli-decompressor 5.0.0", +] + +[[package]] +name = "brotli-decompressor" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a334ef7c9e23abf0ce748e8cd309037da93e606ad52eb372e4ce327a0dcfbdfd" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", ] [[package]] @@ -5762,7 +5785,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa9844cefcf99554a16e0a278156ae73b0d8680bbc0e2ad1e4287aadd8489cf" dependencies = [ "base64 0.22.1", - "brotli", + "brotli 8.0.1", "ico", "json-patch", "plist", @@ -6094,7 +6117,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76a423c51176eb3616ee9b516a9fa67fed5f0e78baaba680e44eb5dd2cc37490" dependencies = [ "anyhow", - "brotli", + "brotli 8.0.1", "cargo_metadata", "ctor", "dunce", @@ -7903,6 +7926,7 @@ dependencies = [ "thiserror 2.0.17", "tokio", "tokio-stream", + "tokio-util", "ts-rs", "uuid", "yaak-common", @@ -7929,6 +7953,7 @@ dependencies = [ "regex", "reqwest", "serde", + "serde_json", "tauri", "thiserror 2.0.17", ] @@ -8010,19 +8035,30 @@ dependencies = [ name = "yaak-http" version = "0.1.0" dependencies = [ + "async-compression", + "async-trait", + "brotli 7.0.0", + "bytes", + "flate2", + "futures-util", "hyper-util", "log", + "mime_guess", "regex", "reqwest", "reqwest_cookie_store", "serde", + "serde_json", "tauri", "thiserror 2.0.17", "tokio", + "tokio-util", "tower-service", "urlencoding", + "yaak-common", "yaak-models", "yaak-tls", + "zstd", ] [[package]] diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index af5ce6c2..19be0850 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -74,6 +74,7 @@ tauri-plugin-window-state = "2.4.1" thiserror = { workspace = true } tokio = { workspace = true, features = ["sync"] } tokio-stream = "0.1.17" +tokio-util = { version = "0.7", features = ["codec"] } ts-rs = { workspace = true } uuid = "1.12.1" yaak-common = { workspace = true } diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index 4797fb0d..e0541b51 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -59,7 +59,7 @@ pub enum Error { #[error("Request error: {0}")] RequestError(#[from] reqwest::Error), - #[error("Generic error: {0}")] + #[error("{0}")] GenericError(String), } diff --git a/src-tauri/src/http_request.rs b/src-tauri/src/http_request.rs index dfde8946..15cf6041 100644 --- a/src-tauri/src/http_request.rs +++ b/src-tauri/src/http_request.rs @@ -2,32 +2,25 @@ use crate::error::Error::GenericError; use crate::error::Result; use crate::render::render_http_request; use crate::response_err; -use http::header::{ACCEPT, USER_AGENT}; -use http::{HeaderMap, HeaderName, HeaderValue}; -use log::{debug, error, warn}; -use mime_guess::Mime; -use reqwest::{Method, Response}; -use reqwest::{Url, multipart}; +use log::debug; use reqwest_cookie_store::{CookieStore, CookieStoreMutex}; -use serde_json::Value; -use std::collections::BTreeMap; -use std::path::PathBuf; -use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; -use tauri::{Manager, Runtime, WebviewWindow}; -use tokio::fs; +use std::time::{Duration, Instant}; +use tauri::{AppHandle, Manager, Runtime, WebviewWindow}; use tokio::fs::{File, 
create_dir_all}; -use tokio::io::AsyncWriteExt; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Mutex; use tokio::sync::watch::Receiver; -use tokio::sync::{Mutex, oneshot}; use yaak_http::client::{ HttpConnectionOptions, HttpConnectionProxySetting, HttpConnectionProxySettingAuth, }; use yaak_http::manager::HttpConnectionManager; +use yaak_http::sender::ReqwestSender; +use yaak_http::transaction::HttpTransaction; +use yaak_http::types::{SendableHttpRequest, SendableHttpRequestOptions, append_query_params}; use yaak_models::models::{ - Cookie, CookieJar, Environment, HttpRequest, HttpResponse, HttpResponseHeader, - HttpResponseState, ProxySetting, ProxySettingAuth, + CookieJar, Environment, HttpRequest, HttpResponse, HttpResponseHeader, HttpResponseState, + ProxySetting, ProxySettingAuth, }; use yaak_models::query_manager::QueryManagerExt; use yaak_models::util::UpdateSource; @@ -36,7 +29,7 @@ use yaak_plugins::events::{ }; use yaak_plugins::manager::PluginManager; use yaak_plugins::template_callback::PluginTemplateCallback; -use yaak_templates::{RenderErrorBehavior, RenderOptions}; +use yaak_templates::RenderOptions; use yaak_tls::find_client_certificate; pub async fn send_http_request( @@ -65,62 +58,70 @@ pub async fn send_http_request_with_context( og_response: &HttpResponse, environment: Option, cookie_jar: Option, - cancelled_rx: &mut Receiver, + cancelled_rx: &Receiver, + plugin_context: &PluginContext, +) -> Result { + let app_handle = window.app_handle().clone(); + let response = Arc::new(Mutex::new(og_response.clone())); + let update_source = UpdateSource::from_window(window); + + // Execute the inner send logic and handle errors consistently + let result = send_http_request_inner( + window, + unrendered_request, + og_response, + environment, + cookie_jar, + cancelled_rx, + plugin_context, + ) + .await; + + match result { + Ok(response) => Ok(response), + Err(e) => { + Ok(response_err(&app_handle, &*response.lock().await, e.to_string(), &update_source)) + } + } +} + +async fn send_http_request_inner( + window: &WebviewWindow, + unrendered_request: &HttpRequest, + og_response: &HttpResponse, + environment: Option, + cookie_jar: Option, + cancelled_rx: &Receiver, plugin_context: &PluginContext, ) -> Result { let app_handle = window.app_handle().clone(); let plugin_manager = app_handle.state::(); let connection_manager = app_handle.state::(); let settings = window.db().get_settings(); - let workspace = window.db().get_workspace(&unrendered_request.workspace_id)?; - let environment_id = environment.map(|e| e.id); - let environment_chain = window.db().resolve_environments( - &unrendered_request.workspace_id, - unrendered_request.folder_id.as_deref(), - environment_id.as_deref(), - )?; - - let response_id = og_response.id.clone(); + let wrk_id = &unrendered_request.workspace_id; + let fld_id = unrendered_request.folder_id.as_deref(); + let env_id = environment.map(|e| e.id); + let resp_id = og_response.id.clone(); + let workspace = window.db().get_workspace(wrk_id)?; let response = Arc::new(Mutex::new(og_response.clone())); - let update_source = UpdateSource::from_window(window); - - let (resolved_request, auth_context_id) = match resolve_http_request(window, unrendered_request) - { - Ok(r) => r, - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - e.to_string(), - &update_source, - )); - } - }; - + let (resolved, auth_context_id) = resolve_http_request(window, unrendered_request)?; let cb = PluginTemplateCallback::new(window.app_handle(), 
&plugin_context, RenderPurpose::Send); + let env_chain = window.db().resolve_environments(&workspace.id, fld_id, env_id.as_deref())?; + let request = render_http_request(&resolved, env_chain, &cb, &RenderOptions::throw()).await?; - let opt = RenderOptions { error_behavior: RenderErrorBehavior::Throw }; - - let request = match render_http_request(&resolved_request, environment_chain, &cb, &opt).await { - Ok(r) => r, - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - e.to_string(), - &update_source, - )); - } + // Build the sendable request using the new SendableHttpRequest type + let options = SendableHttpRequestOptions { + follow_redirects: workspace.setting_follow_redirects, + timeout: if workspace.setting_request_timeout > 0 { + Some(Duration::from_millis(workspace.setting_request_timeout.unsigned_abs() as u64)) + } else { + None + }, }; + let mut sendable_request = SendableHttpRequest::from_http_request(&request, options).await?; - let mut url_string = request.url.clone(); - - url_string = ensure_proto(&url_string); - if !url_string.starts_with("http://") && !url_string.starts_with("https://") { - url_string = format!("http://{}", url_string); - } - debug!("Sending request to {} {url_string}", request.method); + debug!("Sending request to {} {}", sendable_request.method, sendable_request.url); let proxy_setting = match settings.proxy { None => HttpConnectionProxySetting::System, @@ -144,7 +145,8 @@ pub async fn send_http_request_with_context( } }; - let client_certificate = find_client_certificate(&url_string, &settings.client_certificates); + let client_certificate = + find_client_certificate(&sendable_request.url, &settings.client_certificates); // Add cookie store if specified let maybe_cookie_manager = match cookie_jar.clone() { @@ -175,523 +177,53 @@ pub async fn send_http_request_with_context( let client = connection_manager .get_client(&HttpConnectionOptions { id: plugin_context.id.clone(), - follow_redirects: workspace.setting_follow_redirects, validate_certificates: workspace.setting_validate_certificates, proxy: proxy_setting, cookie_provider: maybe_cookie_manager.as_ref().map(|(p, _)| Arc::clone(&p)), client_certificate, - timeout: if workspace.setting_request_timeout > 0 { - Some(Duration::from_millis(workspace.setting_request_timeout.unsigned_abs() as u64)) - } else { - None - }, }) .await?; - // Render query parameters - let mut query_params = Vec::new(); - for p in request.url_parameters.clone() { - if !p.enabled || p.name.is_empty() { - continue; - } - query_params.push((p.name, p.value)); + // Apply authentication to the request + apply_authentication( + &window, + &mut sendable_request, + &request, + auth_context_id, + &plugin_manager, + plugin_context, + ) + .await?; + + let start_for_cancellation = Instant::now(); + let final_resp = execute_transaction( + client, + sendable_request, + response.clone(), + &resp_id, + &app_handle, + &update_source, + cancelled_rx.clone(), + ) + .await; + + match final_resp { + Ok(r) => Ok(r), + Err(e) => match app_handle.db().get_http_response(&resp_id) { + Ok(mut r) => { + r.state = HttpResponseState::Closed; + r.elapsed = start_for_cancellation.elapsed().as_millis() as i32; + r.elapsed_headers = start_for_cancellation.elapsed().as_millis() as i32; + r.error = Some(e.to_string()); + app_handle + .db() + .update_http_response_if_id(&r, &UpdateSource::from_window(window)) + .expect("Failed to update response"); + Ok(r) + } + _ => Err(GenericError("Ephemeral request was cancelled".to_string())), + }, } - 
- let url = match Url::from_str(&url_string) { - Ok(u) => u, - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - format!("Failed to parse URL \"{}\": {}", url_string, e.to_string()), - &update_source, - )); - } - }; - - let m = Method::from_str(&request.method.to_uppercase()) - .map_err(|e| GenericError(e.to_string()))?; - let mut request_builder = client.request(m, url).query(&query_params); - - let mut headers = HeaderMap::new(); - headers.insert(USER_AGENT, HeaderValue::from_static("yaak")); - headers.insert(ACCEPT, HeaderValue::from_static("*/*")); - - // TODO: Set cookie header ourselves once we also handle redirects. We need to do this - // because reqwest doesn't give us a way to inspect the headers it sent (we have to do - // everything manually to know that). - // if let Some(cookie_store) = maybe_cookie_store.clone() { - // let values1 = cookie_store.get_request_values(&url); - // let raw_value = cookie_store.get_request_values(&url) - // .map(|(name, value)| format!("{}={}", name, value)) - // .collect::>() - // .join("; "); - // headers.insert( - // COOKIE, - // HeaderValue::from_str(&raw_value).expect("Failed to create cookie header"), - // ); - // } - - for h in request.headers.clone() { - if h.name.is_empty() && h.value.is_empty() { - continue; - } - - if !h.enabled { - continue; - } - - let header_name = match HeaderName::from_str(&h.name) { - Ok(n) => n, - Err(e) => { - error!("Failed to create header name: {}", e); - continue; - } - }; - let header_value = match HeaderValue::from_str(&h.value) { - Ok(n) => n, - Err(e) => { - error!("Failed to create header value: {}", e); - continue; - } - }; - - headers.insert(header_name, header_value); - } - - let request_body = request.body.clone(); - if let Some(body_type) = &request.body_type.clone() { - if body_type == "graphql" { - let query = get_str_h(&request_body, "query"); - let variables = get_str_h(&request_body, "variables"); - if request.method.to_lowercase() == "get" { - request_builder = request_builder.query(&[("query", query)]); - if !variables.trim().is_empty() { - request_builder = request_builder.query(&[("variables", variables)]); - } - } else { - let body = if variables.trim().is_empty() { - format!(r#"{{"query":{}}}"#, serde_json::to_string(query).unwrap_or_default()) - } else { - format!( - r#"{{"query":{},"variables":{variables}}}"#, - serde_json::to_string(query).unwrap_or_default() - ) - }; - request_builder = request_builder.body(body.to_owned()); - } - } else if body_type == "application/x-www-form-urlencoded" - && request_body.contains_key("form") - { - let mut form_params = Vec::new(); - let form = request_body.get("form"); - if let Some(f) = form { - match f.as_array() { - None => {} - Some(a) => { - for p in a { - let enabled = get_bool(p, "enabled", true); - let name = get_str(p, "name"); - if !enabled || name.is_empty() { - continue; - } - let value = get_str(p, "value"); - form_params.push((name, value)); - } - } - } - } - request_builder = request_builder.form(&form_params); - } else if body_type == "binary" && request_body.contains_key("filePath") { - let file_path = request_body - .get("filePath") - .ok_or(GenericError("filePath not set".to_string()))? 
- .as_str() - .unwrap_or_default(); - - match fs::read(file_path).await.map_err(|e| e.to_string()) { - Ok(f) => { - request_builder = request_builder.body(f); - } - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - e, - &update_source, - )); - } - } - } else if body_type == "multipart/form-data" && request_body.contains_key("form") { - let mut multipart_form = multipart::Form::new(); - if let Some(form_definition) = request_body.get("form") { - match form_definition.as_array() { - None => {} - Some(fd) => { - for p in fd { - let enabled = get_bool(p, "enabled", true); - let name = get_str(p, "name").to_string(); - - if !enabled || name.is_empty() { - continue; - } - - let file_path = get_str(p, "file").to_owned(); - let value = get_str(p, "value").to_owned(); - - let mut part = if file_path.is_empty() { - multipart::Part::text(value.clone()) - } else { - match fs::read(file_path.clone()).await { - Ok(f) => multipart::Part::bytes(f), - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - e.to_string(), - &update_source, - )); - } - } - }; - - let content_type = get_str(p, "contentType"); - - // Set or guess mimetype - if !content_type.is_empty() { - part = match part.mime_str(content_type) { - Ok(p) => p, - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - format!("Invalid mime for multi-part entry {e:?}"), - &update_source, - )); - } - }; - } else if !file_path.is_empty() { - let default_mime = - Mime::from_str("application/octet-stream").unwrap(); - let mime = - mime_guess::from_path(file_path.clone()).first_or(default_mime); - part = match part.mime_str(mime.essence_str()) { - Ok(p) => p, - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - format!("Invalid mime for multi-part entry {e:?}"), - &update_source, - )); - } - }; - } - - // Set a file path if it is not empty - if !file_path.is_empty() { - let user_filename = get_str(p, "filename").to_owned(); - let filename = if user_filename.is_empty() { - PathBuf::from(file_path) - .file_name() - .unwrap_or_default() - .to_string_lossy() - .to_string() - } else { - user_filename - }; - part = part.file_name(filename); - } - - multipart_form = multipart_form.part(name, part); - } - } - } - } - headers.remove("Content-Type"); // reqwest will add this automatically - request_builder = request_builder.multipart(multipart_form); - } else if request_body.contains_key("text") { - let body = get_str_h(&request_body, "text"); - request_builder = request_builder.body(body.to_owned()); - } else { - warn!("Unsupported body type: {}", body_type); - } - } else { - // No body set - let method = request.method.to_ascii_lowercase(); - let is_body_method = method == "post" || method == "put" || method == "patch"; - // Add Content-Length for methods that commonly accept a body because some servers - // will error if they don't receive it. - if is_body_method && !headers.contains_key("content-length") { - headers.insert("Content-Length", HeaderValue::from_static("0")); - } - } - - // Add headers last, because previous steps may modify them - request_builder = request_builder.headers(headers); - - let mut sendable_req = match request_builder.build() { - Ok(r) => r, - Err(e) => { - warn!("Failed to build request builder {e:?}"); - return Ok(response_err( - &app_handle, - &*response.lock().await, - e.to_string(), - &update_source, - )); - } - }; - - match request.authentication_type { - None => { - // No authentication found. 
Not even inherited - } - Some(authentication_type) if authentication_type == "none" => { - // Explicitly no authentication - } - Some(authentication_type) => { - let req = CallHttpAuthenticationRequest { - context_id: format!("{:x}", md5::compute(auth_context_id)), - values: serde_json::from_value(serde_json::to_value(&request.authentication)?)?, - url: sendable_req.url().to_string(), - method: sendable_req.method().to_string(), - headers: sendable_req - .headers() - .iter() - .map(|(name, value)| HttpHeader { - name: name.to_string(), - value: value.to_str().unwrap_or_default().to_string(), - }) - .collect(), - }; - let auth_result = plugin_manager - .call_http_authentication(&window, &authentication_type, req, plugin_context) - .await; - let plugin_result = match auth_result { - Ok(r) => r, - Err(e) => { - return Ok(response_err( - &app_handle, - &*response.lock().await, - e.to_string(), - &update_source, - )); - } - }; - - let headers = sendable_req.headers_mut(); - for header in plugin_result.set_headers.unwrap_or_default() { - match (HeaderName::from_str(&header.name), HeaderValue::from_str(&header.value)) { - (Ok(name), Ok(value)) => { - headers.insert(name, value); - } - _ => continue, - }; - } - - if let Some(params) = plugin_result.set_query_parameters { - let mut query_pairs = sendable_req.url_mut().query_pairs_mut(); - for p in params { - query_pairs.append_pair(&p.name, &p.value); - } - } - } - } - - let (resp_tx, resp_rx) = oneshot::channel::>(); - let (done_tx, done_rx) = oneshot::channel::(); - - let start = std::time::Instant::now(); - - tokio::spawn(async move { - let _ = resp_tx.send(client.execute(sendable_req).await); - }); - - let raw_response = tokio::select! { - Ok(r) = resp_rx => r, - _ = cancelled_rx.changed() => { - let mut r = response.lock().await; - r.elapsed_headers = start.elapsed().as_millis() as i32; - r.elapsed = start.elapsed().as_millis() as i32; - return Ok(response_err(&app_handle, &r, "Request was cancelled".to_string(), &update_source)); - } - }; - - { - let app_handle = app_handle.clone(); - let window = window.clone(); - let cancelled_rx = cancelled_rx.clone(); - let response_id = response_id.clone(); - let response = response.clone(); - let update_source = update_source.clone(); - tokio::spawn(async move { - match raw_response { - Ok(mut v) => { - let content_length = v.content_length(); - let response_headers = v.headers().clone(); - let dir = app_handle.path().app_data_dir().unwrap(); - let base_dir = dir.join("responses"); - create_dir_all(base_dir.clone()).await.expect("Failed to create responses dir"); - let body_path = if response_id.is_empty() { - base_dir.join(uuid::Uuid::new_v4().to_string()) - } else { - base_dir.join(response_id.clone()) - }; - - { - let mut r = response.lock().await; - r.body_path = Some(body_path.to_str().unwrap().to_string()); - r.elapsed_headers = start.elapsed().as_millis() as i32; - r.elapsed = start.elapsed().as_millis() as i32; - r.status = v.status().as_u16() as i32; - r.status_reason = v.status().canonical_reason().map(|s| s.to_string()); - r.headers = response_headers - .iter() - .map(|(k, v)| HttpResponseHeader { - name: k.as_str().to_string(), - value: v.to_str().unwrap_or_default().to_string(), - }) - .collect(); - r.url = v.url().to_string(); - r.remote_addr = v.remote_addr().map(|a| a.to_string()); - r.version = match v.version() { - reqwest::Version::HTTP_09 => Some("HTTP/0.9".to_string()), - reqwest::Version::HTTP_10 => Some("HTTP/1.0".to_string()), - reqwest::Version::HTTP_11 => 
Some("HTTP/1.1".to_string()), - reqwest::Version::HTTP_2 => Some("HTTP/2".to_string()), - reqwest::Version::HTTP_3 => Some("HTTP/3".to_string()), - _ => None, - }; - - r.state = HttpResponseState::Connected; - app_handle - .db() - .update_http_response_if_id(&r, &update_source) - .expect("Failed to update response after connected"); - } - - // Write body to FS - let mut f = File::options() - .create(true) - .truncate(true) - .write(true) - .open(&body_path) - .await - .expect("Failed to open file"); - - let mut written_bytes: usize = 0; - loop { - let chunk = v.chunk().await; - if *cancelled_rx.borrow() { - // Request was canceled - return; - } - match chunk { - Ok(Some(bytes)) => { - let mut r = response.lock().await; - r.elapsed = start.elapsed().as_millis() as i32; - f.write_all(&bytes).await.expect("Failed to write to file"); - f.flush().await.expect("Failed to flush file"); - written_bytes += bytes.len(); - r.content_length = Some(written_bytes as i32); - app_handle - .db() - .update_http_response_if_id(&r, &update_source) - .expect("Failed to update response"); - } - Ok(None) => { - break; - } - Err(e) => { - response_err( - &app_handle, - &*response.lock().await, - e.to_string(), - &update_source, - ); - break; - } - } - } - - // Set the final content length - { - let mut r = response.lock().await; - r.content_length = match content_length { - Some(l) => Some(l as i32), - None => Some(written_bytes as i32), - }; - r.state = HttpResponseState::Closed; - app_handle - .db() - .update_http_response_if_id(&r, &UpdateSource::from_window(&window)) - .expect("Failed to update response"); - }; - - // Add cookie store if specified - if let Some((cookie_store, mut cookie_jar)) = maybe_cookie_manager { - // let cookies = response_headers.get_all(SET_COOKIE).iter().map(|h| { - // println!("RESPONSE COOKIE: {}", h.to_str().unwrap()); - // cookie_store::RawCookie::from_str(h.to_str().unwrap()) - // .expect("Failed to parse cookie") - // }); - // store.store_response_cookies(cookies, &url); - - let json_cookies: Vec = cookie_store - .lock() - .unwrap() - .iter_any() - .map(|c| { - let json_cookie = - serde_json::to_value(&c).expect("Failed to serialize cookie"); - serde_json::from_value(json_cookie) - .expect("Failed to deserialize cookie") - }) - .collect::>(); - cookie_jar.cookies = json_cookies; - if let Err(e) = app_handle - .db() - .upsert_cookie_jar(&cookie_jar, &UpdateSource::from_window(&window)) - { - error!("Failed to update cookie jar: {}", e); - }; - } - } - Err(e) => { - warn!("Failed to execute request {e}"); - response_err( - &app_handle, - &*response.lock().await, - format!("{e} → {e:?}"), - &update_source, - ); - } - }; - - let r = response.lock().await.clone(); - done_tx.send(r).unwrap(); - }); - }; - - let app_handle = app_handle.clone(); - Ok(tokio::select! 
{ - Ok(r) = done_rx => r, - _ = cancelled_rx.changed() => { - match app_handle.with_db(|c| c.get_http_response(&response_id)) { - Ok(mut r) => { - r.state = HttpResponseState::Closed; - r.elapsed = start.elapsed().as_millis() as i32; - r.elapsed_headers = start.elapsed().as_millis() as i32; - app_handle.db().update_http_response_if_id(&r, &UpdateSource::from_window(window)) - .expect("Failed to update response") - }, - _ => { - response_err(&app_handle, &*response.lock().await, "Ephemeral request was cancelled".to_string(), &update_source) - }.clone(), - } - } - }) } pub fn resolve_http_request( @@ -711,46 +243,191 @@ pub fn resolve_http_request( Ok((new_request, authentication_context_id)) } -fn ensure_proto(url_str: &str) -> String { - if url_str.starts_with("http://") || url_str.starts_with("https://") { - return url_str.to_string(); +async fn execute_transaction( + client: reqwest::Client, + sendable_request: SendableHttpRequest, + response: Arc>, + response_id: &String, + app_handle: &AppHandle, + update_source: &UpdateSource, + cancelled_rx: Receiver, +) -> Result { + let sender = ReqwestSender::with_client(client); + let transaction = HttpTransaction::new(sender); + let start = Instant::now(); + + // Capture request headers before sending (headers will be moved) + let request_headers: Vec = sendable_request + .headers + .iter() + .map(|(name, value)| HttpResponseHeader { name: name.clone(), value: value.clone() }) + .collect(); + + // Execute the transaction with cancellation support + // This returns the response with headers, but body is not yet consumed + let (http_response, _events) = + transaction.execute_with_cancellation(sendable_request, cancelled_rx.clone()).await?; + + // Prepare the response path before consuming the body + let dir = app_handle.path().app_data_dir()?; + let base_dir = dir.join("responses"); + create_dir_all(&base_dir).await?; + + let body_path = if response_id.is_empty() { + base_dir.join(uuid::Uuid::new_v4().to_string()) + } else { + base_dir.join(&response_id) + }; + + // Extract metadata before consuming the body (headers are available immediately) + let status = http_response.status; + let status_reason = http_response.status_reason.clone(); + let url = http_response.url.clone(); + let remote_addr = http_response.remote_addr.clone(); + let version = http_response.version.clone(); + let content_length = http_response.content_length; + let headers: Vec = http_response + .headers + .iter() + .map(|(name, value)| HttpResponseHeader { name: name.clone(), value: value.clone() }) + .collect(); + let headers_timing = http_response.timing.headers; + + // Update response with headers info and mark as connected + { + let mut r = response.lock().await; + r.body_path = Some( + body_path + .to_str() + .ok_or(GenericError(format!("Invalid path {body_path:?}")))? 
+ .to_string(), + ); + r.elapsed_headers = headers_timing.as_millis() as i32; + r.elapsed = start.elapsed().as_millis() as i32; + r.status = status as i32; + r.status_reason = status_reason.clone(); + r.url = url.clone(); + r.remote_addr = remote_addr.clone(); + r.version = version.clone(); + r.headers = headers.clone(); + r.request_headers = request_headers.clone(); + r.state = HttpResponseState::Connected; + app_handle.db().update_http_response_if_id(&r, &update_source)?; } - // Url::from_str will fail without a proto, so add one - let parseable_url = format!("http://{}", url_str); - if let Ok(u) = Url::from_str(parseable_url.as_str()) { - match u.host() { - Some(host) => { - let h = host.to_string(); - // These TLDs force HTTPS - if h.ends_with(".app") || h.ends_with(".dev") || h.ends_with(".page") { - return format!("https://{url_str}"); - } + // Get the body stream for manual consumption + let mut body_stream = http_response.into_body_stream()?; + + // Open file for writing + let mut file = File::options() + .create(true) + .truncate(true) + .write(true) + .open(&body_path) + .await + .map_err(|e| GenericError(format!("Failed to open file: {}", e)))?; + + // Stream body to file, updating DB on each chunk + let mut written_bytes: usize = 0; + let mut buf = [0u8; 8192]; + + loop { + // Check for cancellation - if we already have headers/body, just close cleanly + if *cancelled_rx.borrow() { + break; + } + + match body_stream.read(&mut buf).await { + Ok(0) => break, // EOF + Ok(n) => { + file.write_all(&buf[..n]) + .await + .map_err(|e| GenericError(format!("Failed to write to file: {}", e)))?; + file.flush() + .await + .map_err(|e| GenericError(format!("Failed to flush file: {}", e)))?; + written_bytes += n; + + // Update response in DB with progress + let mut r = response.lock().await; + r.elapsed = start.elapsed().as_millis() as i32; + r.content_length = Some(written_bytes as i32); + app_handle.db().update_http_response_if_id(&r, &update_source)?; + } + Err(e) => { + return Err(GenericError(format!("Failed to read response body: {}", e))); } - None => {} } } - format!("http://{url_str}") + // Final update with closed state + let mut resp = response.lock().await.clone(); + resp.headers = headers; + resp.request_headers = request_headers; + resp.status = status as i32; + resp.status_reason = status_reason; + resp.url = url; + resp.remote_addr = remote_addr; + resp.version = version; + resp.state = HttpResponseState::Closed; + resp.content_length = match content_length { + Some(l) => Some(l as i32), + None => Some(written_bytes as i32), + }; + resp.elapsed = start.elapsed().as_millis() as i32; + resp.elapsed_headers = headers_timing.as_millis() as i32; + resp.body_path = Some( + body_path.to_str().ok_or(GenericError(format!("Invalid path {body_path:?}",)))?.to_string(), + ); + + app_handle.db().update_http_response_if_id(&resp, &update_source)?; + + Ok(resp) } -fn get_bool(v: &Value, key: &str, fallback: bool) -> bool { - match v.get(key) { - None => fallback, - Some(v) => v.as_bool().unwrap_or(fallback), - } -} +async fn apply_authentication( + window: &WebviewWindow, + sendable_request: &mut SendableHttpRequest, + request: &HttpRequest, + auth_context_id: String, + plugin_manager: &PluginManager, + plugin_context: &PluginContext, +) -> Result<()> { + match &request.authentication_type { + None => { + // No authentication found. 
Not even inherited + } + Some(authentication_type) if authentication_type == "none" => { + // Explicitly no authentication + } + Some(authentication_type) => { + let req = CallHttpAuthenticationRequest { + context_id: format!("{:x}", md5::compute(auth_context_id)), + values: serde_json::from_value(serde_json::to_value(&request.authentication)?)?, + url: sendable_request.url.clone(), + method: sendable_request.method.clone(), + headers: sendable_request + .headers + .iter() + .map(|(name, value)| HttpHeader { + name: name.to_string(), + value: value.to_string(), + }) + .collect(), + }; + let plugin_result = plugin_manager + .call_http_authentication(&window, &authentication_type, req, plugin_context) + .await?; -fn get_str<'a>(v: &'a Value, key: &str) -> &'a str { - match v.get(key) { - None => "", - Some(v) => v.as_str().unwrap_or_default(), - } -} + for header in plugin_result.set_headers.unwrap_or_default() { + sendable_request.insert_header((header.name, header.value)); + } -fn get_str_h<'a>(v: &'a BTreeMap, key: &str) -> &'a str { - match v.get(key) { - None => "", - Some(v) => v.as_str().unwrap_or_default(), + if let Some(params) = plugin_result.set_query_parameters { + let params = params.into_iter().map(|p| (p.name, p.value)).collect::>(); + sendable_request.url = append_query_params(&sendable_request.url, params); + } + } } + Ok(()) } diff --git a/src-tauri/src/render.rs b/src-tauri/src/render.rs index f915bda9..f5ad8fad 100644 --- a/src-tauri/src/render.rs +++ b/src-tauri/src/render.rs @@ -157,7 +157,7 @@ pub async fn render_http_request( let url = parse_and_render(r.url.clone().as_str(), vars, cb, &opt).await?; // This doesn't fit perfectly with the concept of "rendering" but it kind of does - let (url, url_parameters) = apply_path_placeholders(&url, url_parameters); + let (url, url_parameters) = apply_path_placeholders(&url, &url_parameters); Ok(HttpRequest { url, url_parameters, headers, body, authentication, ..r.to_owned() }) } diff --git a/src-tauri/yaak-common/Cargo.toml b/src-tauri/yaak-common/Cargo.toml index be033751..3e872f12 100644 --- a/src-tauri/yaak-common/Cargo.toml +++ b/src-tauri/yaak-common/Cargo.toml @@ -10,3 +10,4 @@ reqwest = { workspace = true, features = ["system-proxy", "gzip"] } thiserror = { workspace = true } regex = "1.11.0" serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } diff --git a/src-tauri/yaak-common/src/lib.rs b/src-tauri/yaak-common/src/lib.rs index d315fede..cac9bcc6 100644 --- a/src-tauri/yaak-common/src/lib.rs +++ b/src-tauri/yaak-common/src/lib.rs @@ -1,4 +1,5 @@ pub mod api_client; pub mod error; pub mod platform; +pub mod serde; pub mod window; diff --git a/src-tauri/yaak-common/src/serde.rs b/src-tauri/yaak-common/src/serde.rs new file mode 100644 index 00000000..683cc25d --- /dev/null +++ b/src-tauri/yaak-common/src/serde.rs @@ -0,0 +1,23 @@ +use serde_json::Value; +use std::collections::BTreeMap; + +pub fn get_bool(v: &Value, key: &str, fallback: bool) -> bool { + match v.get(key) { + None => fallback, + Some(v) => v.as_bool().unwrap_or(fallback), + } +} + +pub fn get_str<'a>(v: &'a Value, key: &str) -> &'a str { + match v.get(key) { + None => "", + Some(v) => v.as_str().unwrap_or_default(), + } +} + +pub fn get_str_map<'a>(v: &'a BTreeMap, key: &str) -> &'a str { + match v.get(key) { + None => "", + Some(v) => v.as_str().unwrap_or_default(), + } +} diff --git a/src-tauri/yaak-http/Cargo.toml b/src-tauri/yaak-http/Cargo.toml index c51ff89f..ec7367a5 100644 --- a/src-tauri/yaak-http/Cargo.toml +++ 
b/src-tauri/yaak-http/Cargo.toml @@ -5,16 +5,27 @@ edition = "2024" publish = false [dependencies] +async-compression = { version = "0.4", features = ["tokio", "gzip", "deflate", "brotli", "zstd"] } +async-trait = "0.1" +brotli = "7" +bytes = "1.5.0" +flate2 = "1" +futures-util = "0.3" +zstd = "0.13" hyper-util = { version = "0.1.17", default-features = false, features = ["client-legacy"] } log = { workspace = true } +mime_guess = "2.0.5" regex = "1.11.1" -reqwest = { workspace = true, features = ["multipart", "cookies", "gzip", "brotli", "deflate", "json", "rustls-tls-manual-roots-no-provider", "socks", "http2"] } +reqwest = { workspace = true, features = ["cookies", "rustls-tls-manual-roots-no-provider", "socks", "http2", "stream"] } reqwest_cookie_store = { workspace = true } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } tauri = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { version = "0.7", features = ["codec", "io", "io-util"] } tower-service = "0.3.3" urlencoding = "2.1.3" +yaak-common = { workspace = true } yaak-models = { workspace = true } yaak-tls = { workspace = true } diff --git a/src-tauri/yaak-http/src/chained_reader.rs b/src-tauri/yaak-http/src/chained_reader.rs new file mode 100644 index 00000000..7b5f10a3 --- /dev/null +++ b/src-tauri/yaak-http/src/chained_reader.rs @@ -0,0 +1,78 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, ReadBuf}; + +/// A stream that chains multiple AsyncRead sources together +pub(crate) struct ChainedReader { + readers: Vec, + current_index: usize, + current_reader: Option>, +} + +#[derive(Clone)] +pub(crate) enum ReaderType { + Bytes(Vec), + FilePath(String), +} + +impl ChainedReader { + pub(crate) fn new(readers: Vec) -> Self { + Self { readers, current_index: 0, current_reader: None } + } +} + +impl AsyncRead for ChainedReader { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + loop { + // Try to read from current reader if we have one + if let Some(ref mut reader) = self.current_reader { + let before_len = buf.filled().len(); + return match Pin::new(reader).poll_read(cx, buf) { + Poll::Ready(Ok(())) => { + if buf.filled().len() == before_len && buf.remaining() > 0 { + // Current reader is exhausted, move to next + self.current_reader = None; + continue; + } + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; + } + + // We need to get the next reader + if self.current_index >= self.readers.len() { + // No more readers + return Poll::Ready(Ok(())); + } + + // Get the next reader + let reader_type = self.readers[self.current_index].clone(); + self.current_index += 1; + + match reader_type { + ReaderType::Bytes(bytes) => { + self.current_reader = Some(Box::new(io::Cursor::new(bytes))); + } + ReaderType::FilePath(path) => { + // We need to handle file opening synchronously in poll_read + // This is a limitation - we'll use blocking file open + match std::fs::File::open(&path) { + Ok(file) => { + // Convert std File to tokio File + let tokio_file = tokio::fs::File::from_std(file); + self.current_reader = Some(Box::new(tokio_file)); + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + } + } +} diff --git a/src-tauri/yaak-http/src/client.rs b/src-tauri/yaak-http/src/client.rs index f8d06acc..f43ebaaf 100644 --- 
a/src-tauri/yaak-http/src/client.rs +++ b/src-tauri/yaak-http/src/client.rs @@ -1,11 +1,9 @@ use crate::dns::LocalhostResolver; use crate::error::Result; use log::{debug, info, warn}; -use reqwest::redirect::Policy; -use reqwest::{Client, Proxy}; +use reqwest::{Client, Proxy, redirect}; use reqwest_cookie_store::CookieStoreMutex; use std::sync::Arc; -use std::time::Duration; use yaak_tls::{ClientCertificateConfig, get_tls_config}; #[derive(Clone)] @@ -29,11 +27,9 @@ pub enum HttpConnectionProxySetting { #[derive(Clone)] pub struct HttpConnectionOptions { pub id: String, - pub follow_redirects: bool, pub validate_certificates: bool, pub proxy: HttpConnectionProxySetting, pub cookie_provider: Option>, - pub timeout: Option, pub client_certificate: Option, } @@ -41,9 +37,11 @@ impl HttpConnectionOptions { pub(crate) fn build_client(&self) -> Result { let mut client = Client::builder() .connection_verbose(true) - .gzip(true) - .brotli(true) - .deflate(true) + .redirect(redirect::Policy::none()) + // Decompression is handled by HttpTransaction, not reqwest + .no_gzip() + .no_brotli() + .no_deflate() .referer(false) .tls_info(true); @@ -55,12 +53,6 @@ impl HttpConnectionOptions { // Configure DNS resolver client = client.dns_resolver(LocalhostResolver::new()); - // Configure redirects - client = client.redirect(match self.follow_redirects { - true => Policy::limited(10), // TODO: Handle redirects natively - false => Policy::none(), - }); - // Configure cookie provider if let Some(p) = &self.cookie_provider { client = client.cookie_provider(Arc::clone(&p)); @@ -79,11 +71,6 @@ impl HttpConnectionOptions { } } - // Configure timeout - if let Some(d) = self.timeout { - client = client.timeout(d); - } - info!( "Building new HTTP client validate_certificates={} client_cert={}", self.validate_certificates, diff --git a/src-tauri/yaak-http/src/decompress.rs b/src-tauri/yaak-http/src/decompress.rs new file mode 100644 index 00000000..e3764ea6 --- /dev/null +++ b/src-tauri/yaak-http/src/decompress.rs @@ -0,0 +1,188 @@ +use crate::error::{Error, Result}; +use async_compression::tokio::bufread::{ + BrotliDecoder, DeflateDecoder as AsyncDeflateDecoder, GzipDecoder, + ZstdDecoder as AsyncZstdDecoder, +}; +use flate2::read::{DeflateDecoder, GzDecoder}; +use std::io::Read; +use tokio::io::{AsyncBufRead, AsyncRead}; + +/// Supported compression encodings +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ContentEncoding { + Gzip, + Deflate, + Brotli, + Zstd, + Identity, +} + +impl ContentEncoding { + /// Parse a Content-Encoding header value into an encoding type. + /// Returns Identity for unknown or missing encodings. + pub fn from_header(value: Option<&str>) -> Self { + match value.map(|s| s.trim().to_lowercase()).as_deref() { + Some("gzip") | Some("x-gzip") => ContentEncoding::Gzip, + Some("deflate") => ContentEncoding::Deflate, + Some("br") => ContentEncoding::Brotli, + Some("zstd") => ContentEncoding::Zstd, + _ => ContentEncoding::Identity, + } + } +} + +/// Result of decompression, containing both the decompressed data and size info +#[derive(Debug)] +pub struct DecompressResult { + pub data: Vec, + pub compressed_size: u64, + pub decompressed_size: u64, +} + +/// Decompress data based on the Content-Encoding. +/// Returns the original data unchanged if encoding is Identity or unknown. 
+pub fn decompress(data: Vec, encoding: ContentEncoding) -> Result { + let compressed_size = data.len() as u64; + + let decompressed = match encoding { + ContentEncoding::Identity => data, + ContentEncoding::Gzip => decompress_gzip(&data)?, + ContentEncoding::Deflate => decompress_deflate(&data)?, + ContentEncoding::Brotli => decompress_brotli(&data)?, + ContentEncoding::Zstd => decompress_zstd(&data)?, + }; + + let decompressed_size = decompressed.len() as u64; + + Ok(DecompressResult { data: decompressed, compressed_size, decompressed_size }) +} + +fn decompress_gzip(data: &[u8]) -> Result> { + let mut decoder = GzDecoder::new(data); + let mut decompressed = Vec::new(); + decoder + .read_to_end(&mut decompressed) + .map_err(|e| Error::DecompressionError(format!("gzip decompression failed: {}", e)))?; + Ok(decompressed) +} + +fn decompress_deflate(data: &[u8]) -> Result> { + let mut decoder = DeflateDecoder::new(data); + let mut decompressed = Vec::new(); + decoder + .read_to_end(&mut decompressed) + .map_err(|e| Error::DecompressionError(format!("deflate decompression failed: {}", e)))?; + Ok(decompressed) +} + +fn decompress_brotli(data: &[u8]) -> Result> { + let mut decompressed = Vec::new(); + brotli::BrotliDecompress(&mut std::io::Cursor::new(data), &mut decompressed) + .map_err(|e| Error::DecompressionError(format!("brotli decompression failed: {}", e)))?; + Ok(decompressed) +} + +fn decompress_zstd(data: &[u8]) -> Result> { + zstd::stream::decode_all(std::io::Cursor::new(data)) + .map_err(|e| Error::DecompressionError(format!("zstd decompression failed: {}", e))) +} + +/// Create a streaming decompressor that wraps an async reader. +/// Returns an AsyncRead that decompresses data on-the-fly. +pub fn streaming_decoder( + reader: R, + encoding: ContentEncoding, +) -> Box { + match encoding { + ContentEncoding::Identity => Box::new(reader), + ContentEncoding::Gzip => Box::new(GzipDecoder::new(reader)), + ContentEncoding::Deflate => Box::new(AsyncDeflateDecoder::new(reader)), + ContentEncoding::Brotli => Box::new(BrotliDecoder::new(reader)), + ContentEncoding::Zstd => Box::new(AsyncZstdDecoder::new(reader)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use flate2::Compression; + use flate2::write::GzEncoder; + use std::io::Write; + + #[test] + fn test_content_encoding_from_header() { + assert_eq!(ContentEncoding::from_header(Some("gzip")), ContentEncoding::Gzip); + assert_eq!(ContentEncoding::from_header(Some("x-gzip")), ContentEncoding::Gzip); + assert_eq!(ContentEncoding::from_header(Some("GZIP")), ContentEncoding::Gzip); + assert_eq!(ContentEncoding::from_header(Some("deflate")), ContentEncoding::Deflate); + assert_eq!(ContentEncoding::from_header(Some("br")), ContentEncoding::Brotli); + assert_eq!(ContentEncoding::from_header(Some("zstd")), ContentEncoding::Zstd); + assert_eq!(ContentEncoding::from_header(Some("identity")), ContentEncoding::Identity); + assert_eq!(ContentEncoding::from_header(Some("unknown")), ContentEncoding::Identity); + assert_eq!(ContentEncoding::from_header(None), ContentEncoding::Identity); + } + + #[test] + fn test_decompress_identity() { + let data = b"hello world".to_vec(); + let result = decompress(data.clone(), ContentEncoding::Identity).unwrap(); + assert_eq!(result.data, data); + assert_eq!(result.compressed_size, 11); + assert_eq!(result.decompressed_size, 11); + } + + #[test] + fn test_decompress_gzip() { + // Compress some data with gzip + let original = b"hello world, this is a test of gzip compression"; + let mut encoder = 
GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(original).unwrap(); + let compressed = encoder.finish().unwrap(); + + let result = decompress(compressed.clone(), ContentEncoding::Gzip).unwrap(); + assert_eq!(result.data, original); + assert_eq!(result.compressed_size, compressed.len() as u64); + assert_eq!(result.decompressed_size, original.len() as u64); + } + + #[test] + fn test_decompress_deflate() { + // Compress some data with deflate + let original = b"hello world, this is a test of deflate compression"; + let mut encoder = flate2::write::DeflateEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(original).unwrap(); + let compressed = encoder.finish().unwrap(); + + let result = decompress(compressed.clone(), ContentEncoding::Deflate).unwrap(); + assert_eq!(result.data, original); + assert_eq!(result.compressed_size, compressed.len() as u64); + assert_eq!(result.decompressed_size, original.len() as u64); + } + + #[test] + fn test_decompress_brotli() { + // Compress some data with brotli + let original = b"hello world, this is a test of brotli compression"; + let mut compressed = Vec::new(); + let mut writer = brotli::CompressorWriter::new(&mut compressed, 4096, 4, 22); + writer.write_all(original).unwrap(); + drop(writer); + + let result = decompress(compressed.clone(), ContentEncoding::Brotli).unwrap(); + assert_eq!(result.data, original); + assert_eq!(result.compressed_size, compressed.len() as u64); + assert_eq!(result.decompressed_size, original.len() as u64); + } + + #[test] + fn test_decompress_zstd() { + // Compress some data with zstd + let original = b"hello world, this is a test of zstd compression"; + let compressed = zstd::stream::encode_all(std::io::Cursor::new(original), 3).unwrap(); + + let result = decompress(compressed.clone(), ContentEncoding::Zstd).unwrap(); + assert_eq!(result.data, original); + assert_eq!(result.compressed_size, compressed.len() as u64); + assert_eq!(result.decompressed_size, original.len() as u64); + } +} diff --git a/src-tauri/yaak-http/src/error.rs b/src-tauri/yaak-http/src/error.rs index a4ef6ac0..bfb063a0 100644 --- a/src-tauri/yaak-http/src/error.rs +++ b/src-tauri/yaak-http/src/error.rs @@ -8,6 +8,21 @@ pub enum Error { #[error(transparent)] TlsError(#[from] yaak_tls::error::Error), + + #[error("Request failed with {0:?}")] + RequestError(String), + + #[error("Request canceled")] + RequestCanceledError, + + #[error("Timeout of {0:?} reached")] + RequestTimeout(std::time::Duration), + + #[error("Decompression error: {0}")] + DecompressionError(String), + + #[error("Failed to read response body: {0}")] + BodyReadError(String), } impl Serialize for Error { diff --git a/src-tauri/yaak-http/src/lib.rs b/src-tauri/yaak-http/src/lib.rs index cdc9fff1..387b7419 100644 --- a/src-tauri/yaak-http/src/lib.rs +++ b/src-tauri/yaak-http/src/lib.rs @@ -2,11 +2,17 @@ use crate::manager::HttpConnectionManager; use tauri::plugin::{Builder, TauriPlugin}; use tauri::{Manager, Runtime}; +mod chained_reader; pub mod client; +pub mod decompress; pub mod dns; pub mod error; pub mod manager; pub mod path_placeholders; +mod proto; +pub mod sender; +pub mod transaction; +pub mod types; pub fn init() -> TauriPlugin { Builder::new("yaak-http") diff --git a/src-tauri/yaak-http/src/path_placeholders.rs b/src-tauri/yaak-http/src/path_placeholders.rs index f80b4a5f..9a700e82 100644 --- a/src-tauri/yaak-http/src/path_placeholders.rs +++ b/src-tauri/yaak-http/src/path_placeholders.rs @@ -2,7 +2,7 @@ use 
yaak_models::models::HttpUrlParameter; pub fn apply_path_placeholders( url: &str, - parameters: Vec, + parameters: &Vec, ) -> (String, Vec) { let mut new_parameters = Vec::new(); @@ -18,7 +18,7 @@ pub fn apply_path_placeholders( // Remove as param if it modified the URL if old_url_string == *url { - new_parameters.push(p); + new_parameters.push(p.to_owned()); } } @@ -156,7 +156,7 @@ mod placeholder_tests { ..Default::default() }; - let (url, url_parameters) = apply_path_placeholders(&req.url, req.url_parameters); + let (url, url_parameters) = apply_path_placeholders(&req.url, &req.url_parameters); // Pattern match back to access it assert_eq!(url, "example.com/aaa/bar"); diff --git a/src-tauri/yaak-http/src/proto.rs b/src-tauri/yaak-http/src/proto.rs new file mode 100644 index 00000000..fd36a0ed --- /dev/null +++ b/src-tauri/yaak-http/src/proto.rs @@ -0,0 +1,29 @@ +use reqwest::Url; +use std::str::FromStr; + +pub(crate) fn ensure_proto(url_str: &str) -> String { + if url_str.is_empty() { + return "".to_string(); + } + + if url_str.starts_with("http://") || url_str.starts_with("https://") { + return url_str.to_string(); + } + + // Url::from_str will fail without a proto, so add one + let parseable_url = format!("http://{}", url_str); + if let Ok(u) = Url::from_str(parseable_url.as_str()) { + match u.host() { + Some(host) => { + let h = host.to_string(); + // These TLDs force HTTPS + if h.ends_with(".app") || h.ends_with(".dev") || h.ends_with(".page") { + return format!("https://{url_str}"); + } + } + None => {} + } + } + + format!("http://{url_str}") +} diff --git a/src-tauri/yaak-http/src/sender.rs b/src-tauri/yaak-http/src/sender.rs new file mode 100644 index 00000000..236dcb45 --- /dev/null +++ b/src-tauri/yaak-http/src/sender.rs @@ -0,0 +1,409 @@ +use crate::decompress::{ContentEncoding, streaming_decoder}; +use crate::error::{Error, Result}; +use crate::types::{SendableBody, SendableHttpRequest}; +use async_trait::async_trait; +use futures_util::StreamExt; +use reqwest::{Client, Method, Version}; +use std::collections::HashMap; +use std::fmt::Display; +use std::pin::Pin; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncRead, AsyncReadExt, BufReader}; +use tokio_util::io::StreamReader; + +#[derive(Debug, Default, Clone)] +pub struct HttpResponseTiming { + pub headers: Duration, + pub body: Duration, +} + +#[derive(Debug)] +pub enum HttpResponseEvent { + Setting(String, String), + Info(String), + SendUrl { method: String, path: String }, + ReceiveUrl { version: Version, status: String }, + HeaderUp(String, String), + HeaderDown(String, String), + HeaderUpDone, + HeaderDownDone, +} + +impl Display for HttpResponseEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HttpResponseEvent::Setting(name, value) => write!(f, "* Setting {}={}", name, value), + HttpResponseEvent::Info(s) => write!(f, "* {}", s), + HttpResponseEvent::SendUrl { method, path } => write!(f, "> {} {}", method, path), + HttpResponseEvent::ReceiveUrl { version, status } => { + write!(f, "< {} {}", version_to_str(version), status) + } + HttpResponseEvent::HeaderUp(name, value) => write!(f, "> {}: {}", name, value), + HttpResponseEvent::HeaderUpDone => write!(f, ">"), + HttpResponseEvent::HeaderDown(name, value) => write!(f, "< {}: {}", name, value), + HttpResponseEvent::HeaderDownDone => write!(f, "<"), + } + } +} + +/// Statistics about the body after consumption +#[derive(Debug, Default, Clone)] +pub struct BodyStats { + /// Size of the body as received over the wire 
(before decompression) + pub size_compressed: u64, + /// Size of the body after decompression + pub size_decompressed: u64, +} + +/// Type alias for the body stream +type BodyStream = Pin>; + +/// HTTP response with deferred body consumption. +/// Headers are available immediately after send(), body can be consumed in different ways. +/// Note: Debug is manually implemented since BodyStream doesn't implement Debug. +pub struct HttpResponse { + /// HTTP status code + pub status: u16, + /// HTTP status reason phrase (e.g., "OK", "Not Found") + pub status_reason: Option, + /// Response headers + pub headers: HashMap, + /// Content-Length from headers (may differ from actual body size) + pub content_length: Option, + /// Final URL (after redirects) + pub url: String, + /// Remote address of the server + pub remote_addr: Option, + /// HTTP version (e.g., "HTTP/1.1", "HTTP/2") + pub version: Option, + /// Timing information + pub timing: HttpResponseTiming, + + /// The body stream (consumed when calling bytes(), text(), write_to_file(), or drain()) + body_stream: Option, + /// Content-Encoding for decompression + encoding: ContentEncoding, + /// Start time for timing the body read + start_time: Instant, +} + +impl std::fmt::Debug for HttpResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HttpResponse") + .field("status", &self.status) + .field("status_reason", &self.status_reason) + .field("headers", &self.headers) + .field("content_length", &self.content_length) + .field("url", &self.url) + .field("remote_addr", &self.remote_addr) + .field("version", &self.version) + .field("timing", &self.timing) + .field("body_stream", &"") + .field("encoding", &self.encoding) + .finish() + } +} + +impl HttpResponse { + /// Create a new HttpResponse with an unconsumed body stream + #[allow(clippy::too_many_arguments)] + pub fn new( + status: u16, + status_reason: Option, + headers: HashMap, + content_length: Option, + url: String, + remote_addr: Option, + version: Option, + timing: HttpResponseTiming, + body_stream: BodyStream, + encoding: ContentEncoding, + start_time: Instant, + ) -> Self { + Self { + status, + status_reason, + headers, + content_length, + url, + remote_addr, + version, + timing, + body_stream: Some(body_stream), + encoding, + start_time, + } + } + + /// Consume the body and return it as bytes (loads entire body into memory). + /// Also decompresses the body if Content-Encoding is set. 
+
+    /// Consume the body and return it as bytes (loads entire body into memory).
+    /// Also decompresses the body if Content-Encoding is set.
+    pub async fn bytes(mut self) -> Result<(Vec<u8>, BodyStats, HttpResponseTiming)> {
+        let stream = self.body_stream.take().ok_or_else(|| {
+            Error::RequestError("Response body has already been consumed".to_string())
+        })?;
+
+        let buf_reader = BufReader::new(stream);
+        let mut decoder = streaming_decoder(buf_reader, self.encoding);
+
+        let mut decompressed = Vec::new();
+        let mut bytes_read = 0u64;
+
+        // Read through the decoder in chunks, counting the decompressed bytes
+        let mut buf = [0u8; 8192];
+        loop {
+            match decoder.read(&mut buf).await {
+                Ok(0) => break,
+                Ok(n) => {
+                    decompressed.extend_from_slice(&buf[..n]);
+                    bytes_read += n as u64;
+                }
+                Err(e) => {
+                    return Err(Error::BodyReadError(e.to_string()));
+                }
+            }
+        }
+
+        let mut timing = self.timing.clone();
+        timing.body = self.start_time.elapsed();
+
+        let stats = BodyStats {
+            // For now, we can't easily track compressed size when streaming through the decoder.
+            // Use content_length as an approximation, or the decompressed size for identity encoding.
+            size_compressed: self.content_length.unwrap_or(bytes_read),
+            size_decompressed: decompressed.len() as u64,
+        };
+
+        Ok((decompressed, stats, timing))
+    }
+
+    /// Consume the body and return it as a UTF-8 string.
+    pub async fn text(self) -> Result<(String, BodyStats, HttpResponseTiming)> {
+        let (bytes, stats, timing) = self.bytes().await?;
+        let text = String::from_utf8(bytes)
+            .map_err(|e| Error::RequestError(format!("Response is not valid UTF-8: {}", e)))?;
+        Ok((text, stats, timing))
+    }
+
+    /// Take the body stream for manual consumption.
+    /// Returns an AsyncRead that decompresses on-the-fly if Content-Encoding is set.
+    /// The caller is responsible for reading and processing the stream.
+    pub fn into_body_stream(mut self) -> Result<impl AsyncRead + Send> {
+        let stream = self.body_stream.take().ok_or_else(|| {
+            Error::RequestError("Response body has already been consumed".to_string())
+        })?;
+
+        let buf_reader = BufReader::new(stream);
+        let decoder = streaming_decoder(buf_reader, self.encoding);
+
+        Ok(decoder)
+    }
+
+    /// Discard the body without reading it (useful for redirects).
+    pub async fn drain(mut self) -> Result<HttpResponseTiming> {
+        let stream = self.body_stream.take().ok_or_else(|| {
+            Error::RequestError("Response body has already been consumed".to_string())
+        })?;
+
+        // Just read and discard all bytes
+        let mut reader = stream;
+        let mut buf = [0u8; 8192];
+        loop {
+            match reader.read(&mut buf).await {
+                Ok(0) => break,
+                Ok(_) => continue,
+                Err(e) => {
+                    return Err(Error::RequestError(format!(
+                        "Failed to drain response body: {}",
+                        e
+                    )));
+                }
+            }
+        }
+
+        let mut timing = self.timing.clone();
+        timing.body = self.start_time.elapsed();
+
+        Ok(timing)
+    }
+}
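Since the body is deferred, each call site picks exactly one consumption path. A sketch of the three options (illustrative, not part of the patch; `sender`, `request`, and `events` are assumed from surrounding code, and the stream is pinned so it is Unpin for tokio::io::copy):

    // let response = sender.send(request, &mut events).await?;
    //
    // 1) Buffer the whole (decompressed) body in memory:
    // let (bytes, stats, timing) = response.bytes().await?;
    //
    // 2) Or stream it manually, decompressing on the fly:
    // let mut reader = Box::pin(response.into_body_stream()?);
    // let mut file = tokio::fs::File::create("out.bin").await?;
    // tokio::io::copy(&mut reader, &mut file).await?;
    //
    // 3) Or throw it away without buffering (used for redirects):
    // let timing = response.drain().await?;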
+
+/// Trait for sending HTTP requests
+#[async_trait]
+pub trait HttpSender: Send + Sync {
+    /// Send an HTTP request and return the response with headers.
+    /// The body is not consumed until you call bytes(), text(), write_to_file(), or drain().
+    async fn send(
+        &self,
+        request: SendableHttpRequest,
+        events: &mut Vec<HttpResponseEvent>,
+    ) -> Result<HttpResponse>;
+}
+
+/// Reqwest-based implementation of HttpSender
+pub struct ReqwestSender {
+    client: Client,
+}
+
+impl ReqwestSender {
+    /// Create a new ReqwestSender with a default client
+    pub fn new() -> Result<Self> {
+        let client = Client::builder().build().map_err(Error::Client)?;
+        Ok(Self { client })
+    }
+
+    /// Create a new ReqwestSender with a custom client
+    pub fn with_client(client: Client) -> Self {
+        Self { client }
+    }
+}
+
+#[async_trait]
+impl HttpSender for ReqwestSender {
+    async fn send(
+        &self,
+        request: SendableHttpRequest,
+        events: &mut Vec<HttpResponseEvent>,
+    ) -> Result<HttpResponse> {
+        // Parse the HTTP method
+        let method = Method::from_bytes(request.method.as_bytes())
+            .map_err(|e| Error::RequestError(format!("Invalid HTTP method: {}", e)))?;
+
+        // Build the request
+        let mut req_builder = self.client.request(method, &request.url);
+
+        // Add headers
+        for header in request.headers {
+            req_builder = req_builder.header(&header.0, &header.1);
+        }
+
+        // Configure timeout
+        if let Some(d) = request.options.timeout
+            && !d.is_zero()
+        {
+            req_builder = req_builder.timeout(d);
+        }
+
+        // Add body
+        match request.body {
+            None => {}
+            Some(SendableBody::Bytes(bytes)) => {
+                req_builder = req_builder.body(bytes);
+            }
+            Some(SendableBody::Stream(stream)) => {
+                // Convert AsyncRead stream to reqwest Body
+                let stream = tokio_util::io::ReaderStream::new(stream);
+                let body = reqwest::Body::wrap_stream(stream);
+                req_builder = req_builder.body(body);
+            }
+        }
+
+        let start = Instant::now();
+        let mut timing = HttpResponseTiming::default();
+
+        // Send the request
+        let sendable_req = req_builder.build()?;
+        events.push(HttpResponseEvent::Setting(
+            "timeout".to_string(),
+            if request.options.timeout.unwrap_or_default().is_zero() {
+                "Infinity".to_string()
+            } else {
+                format!("{:?}", request.options.timeout)
+            },
+        ));
+
+        events.push(HttpResponseEvent::SendUrl {
+            path: sendable_req.url().path().to_string(),
+            method: sendable_req.method().to_string(),
+        });
+
+        for (name, value) in sendable_req.headers() {
+            events.push(HttpResponseEvent::HeaderUp(
+                name.to_string(),
+                value.to_str().unwrap_or_default().to_string(),
+            ));
+        }
+        events.push(HttpResponseEvent::HeaderUpDone);
+        events.push(HttpResponseEvent::Info("Sending request to server".to_string()));
+
+        // Map some errors to our own, so they look nicer
+        let response = self.client.execute(sendable_req).await.map_err(|e| {
+            if reqwest::Error::is_timeout(&e) {
+                Error::RequestTimeout(
+                    request.options.timeout.unwrap_or(Duration::from_secs(0)).clone(),
+                )
+            } else {
+                Error::Client(e)
+            }
+        })?;
+
+        let status = response.status().as_u16();
+        let status_reason = response.status().canonical_reason().map(|s| s.to_string());
+        let url = response.url().to_string();
+        let remote_addr = response.remote_addr().map(|a| a.to_string());
+        let version = Some(version_to_str(&response.version()));
+
+        events.push(HttpResponseEvent::ReceiveUrl {
+            version: response.version(),
+            status: response.status().to_string(),
+        });
+
+        timing.headers = start.elapsed();
+
+        // Extract content length
+        let content_length = response.content_length();
+
+        // Extract headers
+        let mut headers = HashMap::new();
+        for (key, value) in response.headers() {
+            if let Ok(v) = value.to_str() {
+                events.push(HttpResponseEvent::HeaderDown(key.to_string(), v.to_string()));
+                headers.insert(key.to_string(), v.to_string());
+            }
+        }
+        events.push(HttpResponseEvent::HeaderDownDone);
+
+        // Determine content encoding for decompression.
+        // HTTP headers are case-insensitive, so we need to search for any casing
+        let encoding = ContentEncoding::from_header(
+            headers
+                .iter()
+                .find(|(k, _)| k.eq_ignore_ascii_case("content-encoding"))
+                .map(|(_, v)| v.as_str()),
+        );
+
+        // Get the byte stream instead of loading into memory
+        let byte_stream = response.bytes_stream();
+
+        // Convert the stream to an AsyncRead
+        let stream_reader = StreamReader::new(
+            byte_stream.map(|result| result.map_err(|e| std::io::Error::other(e))),
+        );
+
+        let body_stream: BodyStream = Box::pin(stream_reader);
+
+        Ok(HttpResponse::new(
+            status,
+            status_reason,
+            headers,
+            content_length,
+            url,
+            remote_addr,
+            version,
+            timing,
+            body_stream,
+            encoding,
+            start,
+        ))
+    }
+}
+
+fn version_to_str(version: &Version) -> String {
+    match *version {
+        Version::HTTP_09 => "HTTP/0.9".to_string(),
+        Version::HTTP_10 => "HTTP/1.0".to_string(),
+        Version::HTTP_11 => "HTTP/1.1".to_string(),
+        Version::HTTP_2 => "HTTP/2".to_string(),
+        Version::HTTP_3 => "HTTP/3".to_string(),
+        _ => "unknown".to_string(),
+    }
+}
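The Display impl above gives the collected event log a curl-style verbose shape. Roughly, for a simple GET (illustrative output assembled from the events pushed in send(); header names render lowercased by reqwest):

    * Setting timeout=Infinity
    > GET /api
    > accept: */*
    >
    * Sending request to server
    < HTTP/1.1 200 OK
    < content-type: application/json
    <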
diff --git a/src-tauri/yaak-http/src/transaction.rs b/src-tauri/yaak-http/src/transaction.rs
new file mode 100644
index 00000000..de4da72a
--- /dev/null
+++ b/src-tauri/yaak-http/src/transaction.rs
@@ -0,0 +1,385 @@
+use crate::error::Result;
+use crate::sender::{HttpResponse, HttpResponseEvent, HttpSender};
+use crate::types::SendableHttpRequest;
+use tokio::sync::watch::Receiver;
+
+/// HTTP Transaction that manages the lifecycle of a request, including redirect handling
+pub struct HttpTransaction<S: HttpSender> {
+    sender: S,
+    max_redirects: usize,
+}
+
+impl<S: HttpSender> HttpTransaction<S> {
+    /// Create a new transaction with default settings
+    pub fn new(sender: S) -> Self {
+        Self { sender, max_redirects: 10 }
+    }
+
+    /// Create a new transaction with custom max redirects
+    pub fn with_max_redirects(sender: S, max_redirects: usize) -> Self {
+        Self { sender, max_redirects }
+    }
+
+    /// Execute the request with cancellation support.
+    /// Returns an HttpResponse with unconsumed body - caller decides how to consume it.
+    pub async fn execute_with_cancellation(
+        &self,
+        request: SendableHttpRequest,
+        mut cancelled_rx: Receiver<bool>,
+    ) -> Result<(HttpResponse, Vec<HttpResponseEvent>)> {
+        let mut redirect_count = 0;
+        let mut current_url = request.url;
+        let mut current_method = request.method;
+        let mut current_headers = request.headers;
+        let mut current_body = request.body;
+        let mut events = Vec::new();
+
+        loop {
+            // Check for cancellation before each request
+            if *cancelled_rx.borrow() {
+                return Err(crate::error::Error::RequestCanceledError);
+            }
+
+            // Build request for this iteration
+            let req = SendableHttpRequest {
+                url: current_url.clone(),
+                method: current_method.clone(),
+                headers: current_headers.clone(),
+                body: current_body,
+                options: request.options.clone(),
+            };
+
+            // Send the request
+            events.push(HttpResponseEvent::Setting(
+                "redirects".to_string(),
+                request.options.follow_redirects.to_string(),
+            ));
+
+            // Execute with cancellation support
+            let response = tokio::select! {
+                result = self.sender.send(req, &mut events) => result?,
+                _ = cancelled_rx.changed() => {
+                    return Err(crate::error::Error::RequestCanceledError);
+                }
+            };
+
+            if !Self::is_redirect(response.status) {
+                // Not a redirect - return the response for caller to consume body
+                return Ok((response, events));
+            }
+
+            if !request.options.follow_redirects {
+                // Redirects disabled - return the redirect response as-is
+                return Ok((response, events));
+            }
+
+            // Check if we've exceeded max redirects
+            if redirect_count >= self.max_redirects {
+                // Drain the response before returning the error
+                let _ = response.drain().await;
+                return Err(crate::error::Error::RequestError(format!(
+                    "Maximum redirect limit ({}) exceeded",
+                    self.max_redirects
+                )));
+            }
+
+            // Extract Location header before draining (headers are available immediately).
+            // HTTP headers are case-insensitive, so we need to search for any casing
+            let location = response
+                .headers
+                .iter()
+                .find(|(k, _)| k.eq_ignore_ascii_case("location"))
+                .map(|(_, v)| v.clone())
+                .ok_or_else(|| {
+                    crate::error::Error::RequestError(
+                        "Redirect response missing Location header".to_string(),
+                    )
+                })?;
+
+            // Also get the status before draining
+            let status = response.status;
+
+            events.push(HttpResponseEvent::Info("Ignoring the response body".to_string()));
+
+            // Drain the redirect response body before following
+            response.drain().await?;
+
+            // Update the request URL
+            current_url = if location.starts_with("http://") || location.starts_with("https://") {
+                // Absolute URL
+                location
+            } else if location.starts_with('/') {
+                // Absolute path - need to extract the base URL from the current request
+                let base_url = Self::extract_base_url(&current_url)?;
+                format!("{}{}", base_url, location)
+            } else {
+                // Relative path - need to resolve relative to the current path
+                let base_path = Self::extract_base_path(&current_url)?;
+                format!("{}/{}", base_path, location)
+            };
+
+            events.push(HttpResponseEvent::Info(format!(
+                "Issuing redirect {} to: {}",
+                redirect_count + 1,
+                current_url
+            )));
+
+            // Handle method changes for certain redirect codes
+            if status == 303 {
+                // 303 See Other always changes the method to GET
+                if current_method != "GET" {
+                    current_method = "GET".to_string();
+                    events.push(HttpResponseEvent::Info("Changing method to GET".to_string()));
+                }
+                // Remove content-related headers
+                current_headers.retain(|h| {
+                    let name_lower = h.0.to_lowercase();
+                    !name_lower.starts_with("content-") && name_lower != "transfer-encoding"
+                });
+            } else if status == 301 || status == 302 {
+                // For 301/302, change POST to GET (common browser behavior)
+                // but keep other methods as-is
+                if current_method == "POST" {
+                    events.push(HttpResponseEvent::Info("Changing method to GET".to_string()));
+                    current_method = "GET".to_string();
+                    // Remove content-related headers
+                    current_headers.retain(|h| {
+                        let name_lower = h.0.to_lowercase();
+                        !name_lower.starts_with("content-") && name_lower != "transfer-encoding"
+                    });
+                }
+            }
+            // For 307 and 308, the method and body are preserved
+
+            // Reset the body for the next iteration: it was moved into the send call,
+            // so every redirect is followed without re-sending a body
+            current_body = None;
+
+            redirect_count += 1;
+        }
+    }
+
+    /// Check if a status code indicates a redirect
+    fn is_redirect(status: u16) -> bool {
+        matches!(status, 301 | 302 | 303 | 307 | 308)
+    }
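Concretely, the three Location shapes resolve like this (a sketch; the helper behavior matches the extract_base_url/extract_base_path tests further down):

    // Given current_url = "https://example.com/a/b":
    //   Location: https://other.dev/x -> https://other.dev/x     (absolute URL, used as-is)
    //   Location: /x                  -> https://example.com/x   (extract_base_url + path)
    //   Location: x                   -> https://example.com/a/x (extract_base_path + "/" + x)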
+
+    /// Extract the base URL (scheme + host) from a full URL
+    fn extract_base_url(url: &str) -> Result<String> {
+        // Find the position after "://"
+        let scheme_end = url.find("://").ok_or_else(|| {
+            crate::error::Error::RequestError(format!("Invalid URL format: {}", url))
+        })?;
+
+        // Find the first '/' after the scheme
+        let path_start = url[scheme_end + 3..].find('/');
+
+        if let Some(idx) = path_start {
+            Ok(url[..scheme_end + 3 + idx].to_string())
+        } else {
+            // No path, return entire URL
+            Ok(url.to_string())
+        }
+    }
+
+    /// Extract the base path (everything except the last segment) from a URL
+    fn extract_base_path(url: &str) -> Result<String> {
+        if let Some(last_slash) = url.rfind('/') {
+            // Don't include the trailing slash if it's part of the host
+            if url[..last_slash].ends_with("://") || url[..last_slash].ends_with(':') {
+                Ok(url.to_string())
+            } else {
+                Ok(url[..last_slash].to_string())
+            }
+        } else {
+            Ok(url.to_string())
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::decompress::ContentEncoding;
+    use crate::sender::{HttpResponseEvent, HttpResponseTiming, HttpSender};
+    use async_trait::async_trait;
+    use std::collections::HashMap;
+    use std::pin::Pin;
+    use std::sync::Arc;
+    use std::time::Instant;
+    use tokio::io::AsyncRead;
+    use tokio::sync::Mutex;
+
+    /// Mock sender for testing
+    struct MockSender {
+        responses: Arc<Mutex<Vec<MockResponse>>>,
+    }
+
+    struct MockResponse {
+        status: u16,
+        headers: HashMap<String, String>,
+        body: Vec<u8>,
+    }
+
+    impl MockSender {
+        fn new(responses: Vec<MockResponse>) -> Self {
+            Self { responses: Arc::new(Mutex::new(responses)) }
+        }
+    }
+
+    #[async_trait]
+    impl HttpSender for MockSender {
+        async fn send(
+            &self,
+            _request: SendableHttpRequest,
+            _events: &mut Vec<HttpResponseEvent>,
+        ) -> Result<HttpResponse> {
+            let mut responses = self.responses.lock().await;
+            if responses.is_empty() {
+                Err(crate::error::Error::RequestError("No more mock responses".to_string()))
+            } else {
+                let mock = responses.remove(0);
+                // Create a simple in-memory stream from the body
+                let body_stream: Pin<Box<dyn AsyncRead + Send>> =
+                    Box::pin(std::io::Cursor::new(mock.body));
+                Ok(HttpResponse::new(
+                    mock.status,
+                    None, // status_reason
+                    mock.headers,
+                    None, // content_length
+                    "https://example.com".to_string(), // url
+                    None, // remote_addr
+                    Some("HTTP/1.1".to_string()), // version
+                    HttpResponseTiming::default(),
+                    body_stream,
+                    ContentEncoding::Identity,
+                    Instant::now(),
+                ))
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_transaction_no_redirect() {
+        let response = MockResponse { status: 200, headers: HashMap::new(), body: b"OK".to_vec() };
+        let sender = MockSender::new(vec![response]);
+        let transaction = HttpTransaction::new(sender);
+
+        let request = SendableHttpRequest {
+            url: "https://example.com".to_string(),
+            method: "GET".to_string(),
+            headers: vec![],
+            ..Default::default()
+        };
+
+        let (_tx, rx) = tokio::sync::watch::channel(false);
+        let (result, _) = transaction.execute_with_cancellation(request, rx).await.unwrap();
+        assert_eq!(result.status, 200);
+
+        // Consume the body to verify it
+        let (body, _, _) = result.bytes().await.unwrap();
+        assert_eq!(body, b"OK");
+    }
+
+    #[tokio::test]
+    async fn test_transaction_single_redirect() {
+        let mut redirect_headers = HashMap::new();
+        redirect_headers.insert("Location".to_string(), "https://example.com/new".to_string());
+
+        let responses = vec![
+            MockResponse { status: 302, headers: redirect_headers, body: vec![] },
+            MockResponse { status: 200, headers: HashMap::new(), body: b"Final".to_vec() },
+        ];
+
+        let sender = MockSender::new(responses);
+        let transaction = HttpTransaction::new(sender);
+
+        let request = SendableHttpRequest {
+            url: "https://example.com/old".to_string(),
+            method: "GET".to_string(),
+            options: crate::types::SendableHttpRequestOptions {
+                follow_redirects: true,
+                ..Default::default()
+            },
+            ..Default::default()
+        };
+
+        let (_tx, rx) = tokio::sync::watch::channel(false);
+        let (result, _) = transaction.execute_with_cancellation(request, rx).await.unwrap();
+        assert_eq!(result.status, 200);
+
+        let (body, _, _) = result.bytes().await.unwrap();
+        assert_eq!(body, b"Final");
+    }
+
+    #[tokio::test]
+    async fn test_transaction_max_redirects_exceeded() {
+        let mut redirect_headers = HashMap::new();
+        redirect_headers.insert("Location".to_string(), "https://example.com/loop".to_string());
+
+        // Create more redirects than allowed
+        let responses: Vec<MockResponse> = (0..12)
+            .map(|_| MockResponse { status: 302, headers: redirect_headers.clone(), body: vec![] })
+            .collect();
+
+        let sender = MockSender::new(responses);
+        let transaction = HttpTransaction::with_max_redirects(sender, 10);
+
+        let request = SendableHttpRequest {
+            url: "https://example.com/start".to_string(),
+            method: "GET".to_string(),
+            options: crate::types::SendableHttpRequestOptions {
+                follow_redirects: true,
+                ..Default::default()
+            },
+            ..Default::default()
+        };
+
+        let (_tx, rx) = tokio::sync::watch::channel(false);
+        let result = transaction.execute_with_cancellation(request, rx).await;
+        if let Err(crate::error::Error::RequestError(msg)) = result {
+            assert!(msg.contains("Maximum redirect limit"));
+        } else {
+            panic!("Expected RequestError with max redirect message. Got {result:?}");
+        }
+    }
+
+    #[test]
+    fn test_is_redirect() {
+        assert!(HttpTransaction::<MockSender>::is_redirect(301));
+        assert!(HttpTransaction::<MockSender>::is_redirect(302));
+        assert!(HttpTransaction::<MockSender>::is_redirect(303));
+        assert!(HttpTransaction::<MockSender>::is_redirect(307));
+        assert!(HttpTransaction::<MockSender>::is_redirect(308));
+        assert!(!HttpTransaction::<MockSender>::is_redirect(200));
+        assert!(!HttpTransaction::<MockSender>::is_redirect(404));
+        assert!(!HttpTransaction::<MockSender>::is_redirect(500));
+    }
+
+    #[test]
+    fn test_extract_base_url() {
+        let result =
+            HttpTransaction::<MockSender>::extract_base_url("https://example.com/path/to/resource");
+        assert_eq!(result.unwrap(), "https://example.com");
+
+        let result = HttpTransaction::<MockSender>::extract_base_url("http://localhost:8080/api");
+        assert_eq!(result.unwrap(), "http://localhost:8080");
+
+        let result = HttpTransaction::<MockSender>::extract_base_url("invalid-url");
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_extract_base_path() {
+        let result = HttpTransaction::<MockSender>::extract_base_path(
+            "https://example.com/path/to/resource",
+        );
+        assert_eq!(result.unwrap(), "https://example.com/path/to");
+
+        let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/single");
+        assert_eq!(result.unwrap(), "https://example.com");
+
+        let result = HttpTransaction::<MockSender>::extract_base_path("https://example.com/");
+        assert_eq!(result.unwrap(), "https://example.com");
+    }
+}
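Putting the pieces together, production wiring is expected to look roughly like this (a sketch, not part of the patch; module paths are illustrative and `http_request`/`options` come from the caller):

    // use yaak_http::sender::ReqwestSender;
    // use yaak_http::transaction::HttpTransaction;
    // use yaak_http::types::SendableHttpRequest;
    //
    // let sender = ReqwestSender::new()?;
    // let transaction = HttpTransaction::new(sender); // 10 redirects max by default
    //
    // let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
    // // Flip to true (e.g., from a UI handler) to abort between or during sends:
    // // cancel_tx.send(true)?;
    //
    // let request = SendableHttpRequest::from_http_request(&http_request, options).await?;
    // let (response, events) = transaction.execute_with_cancellation(request, cancel_rx).await?;
    // let (body, stats, timing) = response.bytes().await?;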
diff --git a/src-tauri/yaak-http/src/types.rs b/src-tauri/yaak-http/src/types.rs
new file mode 100644
index 00000000..cd3432f4
--- /dev/null
+++ b/src-tauri/yaak-http/src/types.rs
@@ -0,0 +1,975 @@
+use crate::chained_reader::{ChainedReader, ReaderType};
+use crate::error::Error::RequestError;
+use crate::error::Result;
+use crate::path_placeholders::apply_path_placeholders;
+use crate::proto::ensure_proto;
+use bytes::Bytes;
+use log::warn;
+use std::collections::BTreeMap;
+use std::pin::Pin;
+use std::time::Duration;
+use tokio::io::AsyncRead;
+use yaak_common::serde::{get_bool, get_str, get_str_map};
+use yaak_models::models::HttpRequest;
+
+pub(crate) const MULTIPART_BOUNDARY: &str = "------YaakFormBoundary";
+
+pub enum SendableBody {
+    Bytes(Bytes),
+    Stream(Pin<Box<dyn AsyncRead + Send + Sync>>),
+}
+
+enum SendableBodyWithMeta {
+    Bytes(Bytes),
+    Stream {
+        data: Pin<Box<dyn AsyncRead + Send + Sync>>,
+        content_length: Option<usize>,
+    },
+}
+
+impl From<SendableBodyWithMeta> for SendableBody {
+    fn from(value: SendableBodyWithMeta) -> Self {
+        match value {
+            SendableBodyWithMeta::Bytes(b) => SendableBody::Bytes(b),
+            SendableBodyWithMeta::Stream { data, .. } => SendableBody::Stream(data),
+        }
+    }
+}
+
+#[derive(Default)]
+pub struct SendableHttpRequest {
+    pub url: String,
+    pub method: String,
+    pub headers: Vec<(String, String)>,
+    pub body: Option<SendableBody>,
+    pub options: SendableHttpRequestOptions,
+}
+
+#[derive(Default, Clone)]
+pub struct SendableHttpRequestOptions {
+    pub timeout: Option<Duration>,
+    pub follow_redirects: bool,
+}
+
+impl SendableHttpRequest {
+    pub async fn from_http_request(
+        r: &HttpRequest,
+        options: SendableHttpRequestOptions,
+    ) -> Result<Self> {
+        let initial_headers = build_headers(r);
+        let (body, headers) = build_body(&r.method, &r.body_type, &r.body, initial_headers).await?;
+
+        Ok(Self {
+            url: build_url(r),
+            method: r.method.to_uppercase(),
+            headers,
+            body: body.into(),
+            options,
+        })
+    }
+
+    pub fn insert_header(&mut self, header: (String, String)) {
+        if let Some(existing) =
+            self.headers.iter_mut().find(|h| h.0.to_lowercase() == header.0.to_lowercase())
+        {
+            existing.1 = header.1;
+        } else {
+            self.headers.push(header);
+        }
+    }
+}
+
+pub fn append_query_params(url: &str, params: Vec<(String, String)>) -> String {
+    let url_string = url.to_string();
+    if params.is_empty() {
+        return url.to_string();
+    }
+
+    // Build query string
+    let query_string = params
+        .iter()
+        .map(|(name, value)| {
+            format!("{}={}", urlencoding::encode(name), urlencoding::encode(value))
+        })
+        .collect::<Vec<String>>()
+        .join("&");
+
+    // Split URL into parts: base URL, query, and fragment
+    let (base_and_query, fragment) = if let Some(hash_pos) = url_string.find('#') {
+        let (before_hash, after_hash) = url_string.split_at(hash_pos);
+        (before_hash.to_string(), Some(after_hash.to_string()))
+    } else {
+        (url_string, None)
+    };
+
+    // Now handle query parameters on the base URL (without fragment)
+    let mut result = if base_and_query.contains('?') {
+        // Check if there's already a query string after the '?'
+        let parts: Vec<&str> = base_and_query.splitn(2, '?').collect();
+        if parts.len() == 2 && !parts[1].trim().is_empty() {
+            // Append with & if there are existing parameters
+            format!("{}&{}", base_and_query, query_string)
+        } else {
+            // Just append the new parameters directly (URL ends with '?')
+            format!("{}{}", base_and_query, query_string)
+        }
+    } else {
+        // No existing query parameters, add with '?'
+        format!("{}?{}", base_and_query, query_string)
+    };
+
+    // Re-append the fragment if it exists
+    if let Some(fragment) = fragment {
+        result.push_str(&fragment);
+    }
+
+    result
+}
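The fragment handling is the subtle part of append_query_params; a few worked cases (mirroring the tests further down):

    // append_query_params("https://a.com/api",      [("x","1")]) -> "https://a.com/api?x=1"
    // append_query_params("https://a.com/api?y=2",  [("x","1")]) -> "https://a.com/api?y=2&x=1"
    // append_query_params("https://a.com/api?",     [("x","1")]) -> "https://a.com/api?x=1"
    // append_query_params("https://a.com/api#frag", [("x","1")]) -> "https://a.com/api?x=1#frag"
    // Names and values are percent-encoded: ("a b", "&") -> "a%20b=%26"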
+ format!("{}?{}", base_and_query, query_string) + }; + + // Re-append the fragment if it exists + if let Some(fragment) = fragment { + result.push_str(&fragment); + } + + result +} + +fn build_url(r: &HttpRequest) -> String { + let (url_string, params) = apply_path_placeholders(&ensure_proto(&r.url), &r.url_parameters); + append_query_params( + &url_string, + params + .iter() + .filter(|p| p.enabled && !p.name.is_empty()) + .map(|p| (p.name.clone(), p.value.clone())) + .collect(), + ) +} + +fn build_headers(r: &HttpRequest) -> Vec<(String, String)> { + r.headers + .iter() + .filter_map(|h| { + if h.enabled && !h.name.is_empty() { + Some((h.name.clone(), h.value.clone())) + } else { + None + } + }) + .collect() +} + +async fn build_body( + method: &str, + body_type: &Option, + body: &BTreeMap, + headers: Vec<(String, String)>, +) -> Result<(Option, Vec<(String, String)>)> { + let body_type = match &body_type { + None => return Ok((None, headers)), + Some(t) => t, + }; + + let (body, content_type) = match body_type.as_str() { + "binary" => (build_binary_body(&body).await?, None), + "graphql" => (build_graphql_body(&method, &body), Some("application/json".to_string())), + "application/x-www-form-urlencoded" => { + (build_form_body(&body), Some("application/x-www-form-urlencoded".to_string())) + } + "multipart/form-data" => build_multipart_body(&body, &headers).await?, + _ if body.contains_key("text") => (build_text_body(&body), None), + t => { + warn!("Unsupported body type: {}", t); + (None, None) + } + }; + + // Add or update the Content-Type header + let mut headers = headers; + if let Some(ct) = content_type { + if let Some(existing) = headers.iter_mut().find(|h| h.0.to_lowercase() == "content-type") { + existing.1 = ct; + } else { + headers.push(("Content-Type".to_string(), ct)); + } + } + + // Check if Transfer-Encoding: chunked is already set + let has_chunked_encoding = headers.iter().any(|h| { + h.0.to_lowercase() == "transfer-encoding" && h.1.to_lowercase().contains("chunked") + }); + + // Add a Content-Length header only if chunked encoding is not being used + if !has_chunked_encoding { + let content_length = match body { + Some(SendableBodyWithMeta::Bytes(ref bytes)) => Some(bytes.len()), + Some(SendableBodyWithMeta::Stream { content_length, .. }) => content_length, + None => None, + }; + + if let Some(cl) = content_length { + headers.push(("Content-Length".to_string(), cl.to_string())); + } + } + + Ok((body.map(|b| b.into()), headers)) +} + +fn build_form_body(body: &BTreeMap) -> Option { + let form_params = match body.get("form").map(|f| f.as_array()) { + Some(Some(f)) => f, + _ => return None, + }; + + let mut body = String::new(); + for p in form_params { + let enabled = get_bool(p, "enabled", true); + let name = get_str(p, "name"); + if !enabled || name.is_empty() { + continue; + } + let value = get_str(p, "value"); + if !body.is_empty() { + body.push('&'); + } + body.push_str(&urlencoding::encode(&name)); + body.push('='); + body.push_str(&urlencoding::encode(&value)); + } + + if body.is_empty() { None } else { Some(SendableBodyWithMeta::Bytes(Bytes::from(body))) } +} + +async fn build_binary_body( + body: &BTreeMap, +) -> Result> { + let file_path = match body.get("filePath").map(|f| f.as_str()) { + Some(Some(f)) => f, + _ => return Ok(None), + }; + + // Open a file for streaming + let content_length = tokio::fs::metadata(file_path) + .await + .map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))? 
+
+async fn build_binary_body(
+    body: &BTreeMap<String, serde_json::Value>,
+) -> Result<Option<SendableBodyWithMeta>> {
+    let file_path = match body.get("filePath").map(|f| f.as_str()) {
+        Some(Some(f)) => f,
+        _ => return Ok(None),
+    };
+
+    // Open a file for streaming
+    let content_length = tokio::fs::metadata(file_path)
+        .await
+        .map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))?
+        .len();
+
+    let file = tokio::fs::File::open(file_path)
+        .await
+        .map_err(|e| RequestError(format!("Failed to open file: {}", e)))?;
+
+    Ok(Some(SendableBodyWithMeta::Stream {
+        data: Box::pin(file),
+        content_length: Some(content_length as usize),
+    }))
+}
+
+fn build_text_body(body: &BTreeMap<String, serde_json::Value>) -> Option<SendableBodyWithMeta> {
+    let text = get_str_map(body, "text");
+    if text.is_empty() {
+        None
+    } else {
+        Some(SendableBodyWithMeta::Bytes(Bytes::from(text.to_string())))
+    }
+}
+
+fn build_graphql_body(
+    method: &str,
+    body: &BTreeMap<String, serde_json::Value>,
+) -> Option<SendableBodyWithMeta> {
+    let query = get_str_map(body, "query");
+    let variables = get_str_map(body, "variables");
+
+    if method.to_lowercase() == "get" {
+        // GraphQL GET requests use query parameters, not a body
+        return None;
+    }
+
+    let body = if variables.trim().is_empty() {
+        format!(r#"{{"query":{}}}"#, serde_json::to_string(&query).unwrap_or_default())
+    } else {
+        format!(
+            r#"{{"query":{},"variables":{}}}"#,
+            serde_json::to_string(&query).unwrap_or_default(),
+            variables
+        )
+    };
+
+    Some(SendableBodyWithMeta::Bytes(Bytes::from(body)))
+}
+
+async fn build_multipart_body(
+    body: &BTreeMap<String, serde_json::Value>,
+    headers: &Vec<(String, String)>,
+) -> Result<(Option<SendableBodyWithMeta>, Option<String>)> {
+    let boundary = extract_boundary_from_headers(headers);
+
+    let form_params = match body.get("form").map(|f| f.as_array()) {
+        Some(Some(f)) => f,
+        _ => return Ok((None, None)),
+    };
+
+    // Build a list of readers for streaming and calculate total content length
+    let mut readers: Vec<ReaderType> = Vec::new();
+    let mut has_content = false;
+    let mut total_size: usize = 0;
+
+    for p in form_params {
+        let enabled = get_bool(p, "enabled", true);
+        let name = get_str(p, "name");
+        if !enabled || name.is_empty() {
+            continue;
+        }
+
+        has_content = true;
+
+        // Add boundary delimiter
+        let boundary_bytes = format!("--{}\r\n", boundary).into_bytes();
+        total_size += boundary_bytes.len();
+        readers.push(ReaderType::Bytes(boundary_bytes));
+
+        let file_path = get_str(p, "file");
+        let value = get_str(p, "value");
+        let content_type = get_str(p, "contentType");
+
+        if file_path.is_empty() {
+            // Text field
+            let header =
+                format!("Content-Disposition: form-data; name=\"{}\"\r\n\r\n{}", name, value);
+            let header_bytes = header.into_bytes();
+            total_size += header_bytes.len();
+            readers.push(ReaderType::Bytes(header_bytes));
+        } else {
+            // File field - validate that file exists first
+            if !tokio::fs::try_exists(file_path).await.unwrap_or(false) {
+                return Err(RequestError(format!("File not found: {}", file_path)));
+            }
+
+            // Get file size for content length calculation
+            let file_metadata = tokio::fs::metadata(file_path)
+                .await
+                .map_err(|e| RequestError(format!("Failed to get file metadata: {}", e)))?;
+            let file_size = file_metadata.len() as usize;
+
+            let filename = get_str(p, "filename");
+            let filename = if filename.is_empty() {
+                std::path::Path::new(file_path)
+                    .file_name()
+                    .and_then(|n| n.to_str())
+                    .unwrap_or("file")
+            } else {
+                filename
+            };
+
+            // Add content type
+            let mime_type = if !content_type.is_empty() {
+                content_type.to_string()
+            } else {
+                // Guess mime type from file extension
+                mime_guess::from_path(file_path).first_or_octet_stream().to_string()
+            };
+
+            let header = format!(
+                "Content-Disposition: form-data; name=\"{}\"; filename=\"{}\"\r\nContent-Type: {}\r\n\r\n",
+                name, filename, mime_type
+            );
+            let header_bytes = header.into_bytes();
+            total_size += header_bytes.len();
+            total_size += file_size;
+            readers.push(ReaderType::Bytes(header_bytes));
+
+            // Add a file path for streaming
readers.push(ReaderType::FilePath(file_path.to_string())); + } + + let line_ending = b"\r\n".to_vec(); + total_size += line_ending.len(); + readers.push(ReaderType::Bytes(line_ending)); + } + + if has_content { + // Add the final boundary + let final_boundary = format!("--{}--\r\n", boundary).into_bytes(); + total_size += final_boundary.len(); + readers.push(ReaderType::Bytes(final_boundary)); + + let content_type = format!("multipart/form-data; boundary={}", boundary); + let stream = ChainedReader::new(readers); + Ok(( + Some(SendableBodyWithMeta::Stream { + data: Box::pin(stream), + content_length: Some(total_size), + }), + Some(content_type), + )) + } else { + Ok((None, None)) + } +} + +fn extract_boundary_from_headers(headers: &Vec<(String, String)>) -> String { + headers + .iter() + .find(|h| h.0.to_lowercase() == "content-type") + .and_then(|h| { + // Extract boundary from the Content-Type header (e.g., "multipart/form-data; boundary=xyz") + h.1.split(';') + .find(|part| part.trim().starts_with("boundary=")) + .and_then(|boundary_part| boundary_part.split('=').nth(1)) + .map(|b| b.trim().to_string()) + }) + .unwrap_or_else(|| MULTIPART_BOUNDARY.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use serde_json::json; + use std::collections::BTreeMap; + use yaak_models::models::{HttpRequest, HttpUrlParameter}; + + #[test] + fn test_build_url_no_params() { + let r = HttpRequest { + url: "https://example.com/api".to_string(), + url_parameters: vec![], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api"); + } + + #[test] + fn test_build_url_with_params() { + let r = HttpRequest { + url: "https://example.com/api".to_string(), + url_parameters: vec![ + HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }, + HttpUrlParameter { + enabled: true, + name: "baz".to_string(), + value: "qux".to_string(), + id: None, + }, + ], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?foo=bar&baz=qux"); + } + + #[test] + fn test_build_url_with_disabled_params() { + let r = HttpRequest { + url: "https://example.com/api".to_string(), + url_parameters: vec![ + HttpUrlParameter { + enabled: false, + name: "disabled".to_string(), + value: "value".to_string(), + id: None, + }, + HttpUrlParameter { + enabled: true, + name: "enabled".to_string(), + value: "value".to_string(), + id: None, + }, + ], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?enabled=value"); + } + + #[test] + fn test_build_url_with_existing_query() { + let r = HttpRequest { + url: "https://example.com/api?existing=param".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "new".to_string(), + value: "value".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?existing=param&new=value"); + } + + #[test] + fn test_build_url_with_empty_existing_query() { + let r = HttpRequest { + url: "https://example.com/api?".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "new".to_string(), + value: "value".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?new=value"); + } + + #[test] + fn test_build_url_with_special_chars() { + let r = HttpRequest { + url: 
"https://example.com/api".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "special chars!@#".to_string(), + value: "value with spaces & symbols".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!( + result, + "https://example.com/api?special%20chars%21%40%23=value%20with%20spaces%20%26%20symbols" + ); + } + + #[test] + fn test_build_url_adds_protocol() { + let r = HttpRequest { + url: "example.com/api".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + // ensure_proto defaults to http:// for regular domains + assert_eq!(result, "http://example.com/api?foo=bar"); + } + + #[test] + fn test_build_url_adds_https_for_dev_domain() { + let r = HttpRequest { + url: "example.dev/api".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + // .dev domains force https + assert_eq!(result, "https://example.dev/api?foo=bar"); + } + + #[test] + fn test_build_url_with_fragment() { + let r = HttpRequest { + url: "https://example.com/api#section".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?foo=bar#section"); + } + + #[test] + fn test_build_url_with_existing_query_and_fragment() { + let r = HttpRequest { + url: "https://yaak.app?foo=bar#some-hash".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "baz".to_string(), + value: "qux".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://yaak.app?foo=bar&baz=qux#some-hash"); + } + + #[test] + fn test_build_url_with_empty_query_and_fragment() { + let r = HttpRequest { + url: "https://example.com/api?#section".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?foo=bar#section"); + } + + #[test] + fn test_build_url_with_fragment_containing_special_chars() { + let r = HttpRequest { + url: "https://example.com#section/with/slashes?and=fake&query".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "real".to_string(), + value: "param".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com?real=param#section/with/slashes?and=fake&query"); + } + + #[test] + fn test_build_url_preserves_empty_fragment() { + let r = HttpRequest { + url: "https://example.com/api#".to_string(), + url_parameters: vec![HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + assert_eq!(result, "https://example.com/api?foo=bar#"); + } + + #[test] + fn test_build_url_with_multiple_fragments() { + // Testing edge case where the URL has multiple # characters (though technically invalid) + let r = HttpRequest { + url: "https://example.com#section#subsection".to_string(), + url_parameters: 
vec![HttpUrlParameter { + enabled: true, + name: "foo".to_string(), + value: "bar".to_string(), + id: None, + }], + ..Default::default() + }; + + let result = build_url(&r); + // Should treat everything after first # as fragment + assert_eq!(result, "https://example.com?foo=bar#section#subsection"); + } + + #[tokio::test] + async fn test_text_body() { + let mut body = BTreeMap::new(); + body.insert("text".to_string(), json!("Hello, World!")); + + let result = build_text_body(&body); + match result { + Some(SendableBodyWithMeta::Bytes(bytes)) => { + assert_eq!(bytes, Bytes::from("Hello, World!")) + } + _ => panic!("Expected Some(SendableBody::Bytes)"), + } + } + + #[tokio::test] + async fn test_text_body_empty() { + let mut body = BTreeMap::new(); + body.insert("text".to_string(), json!("")); + + let result = build_text_body(&body); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_text_body_missing() { + let body = BTreeMap::new(); + + let result = build_text_body(&body); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_form_urlencoded_body() -> Result<()> { + let mut body = BTreeMap::new(); + body.insert( + "form".to_string(), + json!([ + { "enabled": true, "name": "basic", "value": "aaa"}, + { "enabled": true, "name": "fUnkey Stuff!$*#(", "value": "*)%&#$)@ *$#)@&"}, + { "enabled": false, "name": "disabled", "value": "won't show"}, + ]), + ); + + let result = build_form_body(&body); + match result { + Some(SendableBodyWithMeta::Bytes(bytes)) => { + let expected = "basic=aaa&fUnkey%20Stuff%21%24%2A%23%28=%2A%29%25%26%23%24%29%40%20%2A%24%23%29%40%26"; + assert_eq!(bytes, Bytes::from(expected)); + } + _ => panic!("Expected Some(SendableBody::Bytes)"), + } + Ok(()) + } + + #[tokio::test] + async fn test_form_urlencoded_body_missing_form() { + let body = BTreeMap::new(); + let result = build_form_body(&body); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_binary_body() -> Result<()> { + let mut body = BTreeMap::new(); + body.insert("filePath".to_string(), json!("./tests/test.txt")); + + let result = build_binary_body(&body).await?; + assert!(matches!(result, Some(SendableBodyWithMeta::Stream { .. 
}))); + Ok(()) + } + + #[tokio::test] + async fn test_binary_body_file_not_found() { + let mut body = BTreeMap::new(); + body.insert("filePath".to_string(), json!("./nonexistent/file.txt")); + + let result = build_binary_body(&body).await; + assert!(result.is_err()); + if let Err(e) = result { + assert!(matches!(e, RequestError(_))); + } + } + + #[tokio::test] + async fn test_graphql_body_with_variables() { + let mut body = BTreeMap::new(); + body.insert("query".to_string(), json!("{ user(id: $id) { name } }")); + body.insert("variables".to_string(), json!(r#"{"id": "123"}"#)); + + let result = build_graphql_body("POST", &body); + match result { + Some(SendableBodyWithMeta::Bytes(bytes)) => { + let expected = + r#"{"query":"{ user(id: $id) { name } }","variables":{"id": "123"}}"#; + assert_eq!(bytes, Bytes::from(expected)); + } + _ => panic!("Expected Some(SendableBody::Bytes)"), + } + } + + #[tokio::test] + async fn test_graphql_body_without_variables() { + let mut body = BTreeMap::new(); + body.insert("query".to_string(), json!("{ users { name } }")); + body.insert("variables".to_string(), json!("")); + + let result = build_graphql_body("POST", &body); + match result { + Some(SendableBodyWithMeta::Bytes(bytes)) => { + let expected = r#"{"query":"{ users { name } }"}"#; + assert_eq!(bytes, Bytes::from(expected)); + } + _ => panic!("Expected Some(SendableBody::Bytes)"), + } + } + + #[tokio::test] + async fn test_graphql_body_get_method() { + let mut body = BTreeMap::new(); + body.insert("query".to_string(), json!("{ users { name } }")); + + let result = build_graphql_body("GET", &body); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_multipart_body_text_fields() -> Result<()> { + let mut body = BTreeMap::new(); + body.insert( + "form".to_string(), + json!([ + { "enabled": true, "name": "field1", "value": "value1", "file": "" }, + { "enabled": true, "name": "field2", "value": "value2", "file": "" }, + { "enabled": false, "name": "disabled", "value": "won't show", "file": "" }, + ]), + ); + + let (result, content_type) = build_multipart_body(&body, &vec![]).await?; + assert!(content_type.is_some()); + + match result { + Some(SendableBodyWithMeta::Stream { data: mut stream, content_length }) => { + // Read the entire stream to verify content + let mut buf = Vec::new(); + use tokio::io::AsyncReadExt; + stream.read_to_end(&mut buf).await.expect("Failed to read stream"); + let body_str = String::from_utf8_lossy(&buf); + assert_eq!( + body_str, + "--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"field1\"\r\n\r\nvalue1\r\n--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"field2\"\r\n\r\nvalue2\r\n--------YaakFormBoundary--\r\n", + ); + assert_eq!(content_length, Some(body_str.len())); + } + _ => panic!("Expected Some(SendableBody::Stream)"), + } + + assert_eq!( + content_type.unwrap(), + format!("multipart/form-data; boundary={}", MULTIPART_BOUNDARY) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_multipart_body_with_file() -> Result<()> { + let mut body = BTreeMap::new(); + body.insert( + "form".to_string(), + json!([ + { "enabled": true, "name": "file_field", "file": "./tests/test.txt", "filename": "custom.txt", "contentType": "text/plain" }, + ]), + ); + + let (result, content_type) = build_multipart_body(&body, &vec![]).await?; + assert!(content_type.is_some()); + + match result { + Some(SendableBodyWithMeta::Stream { data: mut stream, content_length }) => { + // Read the entire stream to verify content + let mut buf = Vec::new(); + 
use tokio::io::AsyncReadExt; + stream.read_to_end(&mut buf).await.expect("Failed to read stream"); + let body_str = String::from_utf8_lossy(&buf); + assert_eq!( + body_str, + "--------YaakFormBoundary\r\nContent-Disposition: form-data; name=\"file_field\"; filename=\"custom.txt\"\r\nContent-Type: text/plain\r\n\r\nThis is a test file!\n\r\n--------YaakFormBoundary--\r\n" + ); + assert_eq!(content_length, Some(body_str.len())); + } + _ => panic!("Expected Some(SendableBody::Stream)"), + } + + assert_eq!( + content_type.unwrap(), + format!("multipart/form-data; boundary={}", MULTIPART_BOUNDARY) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_multipart_body_empty() -> Result<()> { + let body = BTreeMap::new(); + let (result, content_type) = build_multipart_body(&body, &vec![]).await?; + assert!(result.is_none()); + assert_eq!(content_type, None); + Ok(()) + } + + #[test] + fn test_extract_boundary_from_headers_with_custom_boundary() { + let headers = vec![( + "Content-Type".to_string(), + "multipart/form-data; boundary=customBoundary123".to_string(), + )]; + let boundary = extract_boundary_from_headers(&headers); + assert_eq!(boundary, "customBoundary123"); + } + + #[test] + fn test_extract_boundary_from_headers_default() { + let headers = vec![("Accept".to_string(), "*/*".to_string())]; + let boundary = extract_boundary_from_headers(&headers); + assert_eq!(boundary, MULTIPART_BOUNDARY); + } + + #[test] + fn test_extract_boundary_from_headers_no_boundary_in_content_type() { + let headers = vec![("Content-Type".to_string(), "multipart/form-data".to_string())]; + let boundary = extract_boundary_from_headers(&headers); + assert_eq!(boundary, MULTIPART_BOUNDARY); + } + + #[test] + fn test_extract_boundary_case_insensitive() { + let headers = vec![( + "Content-Type".to_string(), + "multipart/form-data; boundary=myBoundary".to_string(), + )]; + let boundary = extract_boundary_from_headers(&headers); + assert_eq!(boundary, "myBoundary"); + } + + #[tokio::test] + async fn test_no_content_length_with_chunked_encoding() -> Result<()> { + let mut body = BTreeMap::new(); + body.insert("text".to_string(), json!("Hello, World!")); + + // Headers with Transfer-Encoding: chunked + let headers = vec![("Transfer-Encoding".to_string(), "chunked".to_string())]; + + let (_, result_headers) = + build_body("POST", &Some("text/plain".to_string()), &body, headers).await?; + + // Verify that Content-Length is NOT present when Transfer-Encoding: chunked is set + let has_content_length = + result_headers.iter().any(|h| h.0.to_lowercase() == "content-length"); + assert!(!has_content_length, "Content-Length should not be present with chunked encoding"); + + // Verify that the Transfer-Encoding header is still present + let has_chunked = result_headers.iter().any(|h| { + h.0.to_lowercase() == "transfer-encoding" && h.1.to_lowercase().contains("chunked") + }); + assert!(has_chunked, "Transfer-Encoding: chunked should be preserved"); + + Ok(()) + } + + #[tokio::test] + async fn test_content_length_without_chunked_encoding() -> Result<()> { + let mut body = BTreeMap::new(); + body.insert("text".to_string(), json!("Hello, World!")); + + // Headers without Transfer-Encoding: chunked + let headers = vec![]; + + let (_, result_headers) = + build_body("POST", &Some("text/plain".to_string()), &body, headers).await?; + + // Verify that Content-Length IS present when Transfer-Encoding: chunked is NOT set + let content_length_header = + result_headers.iter().find(|h| h.0.to_lowercase() == "content-length"); + assert!( + 
            content_length_header.is_some(),
            "Content-Length should be present without chunked encoding"
        );
        assert_eq!(
            content_length_header.unwrap().1,
            "13",
            "Content-Length should match the body size"
        );

        Ok(())
    }
}
diff --git a/src-tauri/yaak-http/tests/test.txt b/src-tauri/yaak-http/tests/test.txt
new file mode 100644
index 00000000..c66d471e
--- /dev/null
+++ b/src-tauri/yaak-http/tests/test.txt
@@ -0,0 +1 @@
+This is a test file!
diff --git a/src-tauri/yaak-models/bindings/gen_models.ts b/src-tauri/yaak-models/bindings/gen_models.ts
index ce16e13a..055e5e35 100644
--- a/src-tauri/yaak-models/bindings/gen_models.ts
+++ b/src-tauri/yaak-models/bindings/gen_models.ts
@@ -38,7 +38,7 @@ export type HttpRequest = { model: "http_request", id: string, createdAt: string
 export type HttpRequestHeader = { enabled?: boolean, name: string, value: string, id?: string, };
 
-export type HttpResponse = { model: "http_response", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, bodyPath: string | null, contentLength: number | null, elapsed: number, elapsedHeaders: number, error: string | null, headers: Array<HttpResponseHeader>, remoteAddr: string | null, status: number, statusReason: string | null, state: HttpResponseState, url: string, version: string | null, };
+export type HttpResponse = { model: "http_response", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, bodyPath: string | null, contentLength: number | null, contentLengthCompressed: number | null, elapsed: number, elapsedHeaders: number, error: string | null, headers: Array<HttpResponseHeader>, remoteAddr: string | null, requestHeaders: Array<HttpRequestHeader>, status: number, statusReason: string | null, state: HttpResponseState, url: string, version: string | null, };
 
 export type HttpResponseHeader = { name: string, value: string, };
diff --git a/src-tauri/yaak-models/migrations/20251219074602_default-workspace-headers.sql b/src-tauri/yaak-models/migrations/20251219074602_default-workspace-headers.sql
new file mode 100644
index 00000000..8793b2a5
--- /dev/null
+++ b/src-tauri/yaak-models/migrations/20251219074602_default-workspace-headers.sql
@@ -0,0 +1,15 @@
+-- Add default User-Agent header to workspaces that don't already have one (case-insensitive check)
+UPDATE workspaces
+SET headers = json_insert(headers, '$[#]', json('{"enabled":true,"name":"User-Agent","value":"yaak"}'))
+WHERE NOT EXISTS (
+    SELECT 1 FROM json_each(workspaces.headers)
+    WHERE LOWER(json_extract(value, '$.name')) = 'user-agent'
+);
+
+-- Add default Accept header to workspaces that don't already have one (case-insensitive check)
+UPDATE workspaces
+SET headers = json_insert(headers, '$[#]', json('{"enabled":true,"name":"Accept","value":"*/*"}'))
+WHERE NOT EXISTS (
+    SELECT 1 FROM json_each(workspaces.headers)
+    WHERE LOWER(json_extract(value, '$.name')) = 'accept'
+);
diff --git a/src-tauri/yaak-models/migrations/20251220000000_response-request-headers.sql b/src-tauri/yaak-models/migrations/20251220000000_response-request-headers.sql
new file mode 100644
index 00000000..36e73aeb
--- /dev/null
+++ b/src-tauri/yaak-models/migrations/20251220000000_response-request-headers.sql
@@ -0,0 +1,3 @@
+-- Add request_headers and content_length_compressed columns to http_responses table
+ALTER TABLE http_responses ADD COLUMN request_headers TEXT NOT NULL DEFAULT '[]';
+ALTER TABLE http_responses ADD COLUMN content_length_compressed INTEGER;
diff --git a/src-tauri/yaak-models/src/error.rs b/src-tauri/yaak-models/src/error.rs
index 466e75b6..f92785a1 100644
--- a/src-tauri/yaak-models/src/error.rs
+++ b/src-tauri/yaak-models/src/error.rs
@@ -18,7 +18,7 @@ pub enum Error {
     #[error("Model serialization error: {0}")]
     ModelSerializationError(String),
 
-    #[error("Model error: {0}")]
+    #[error("HTTP error: {0}")]
     GenericError(String),
 
     #[error("DB Migration Failed: {0}")]
diff --git a/src-tauri/yaak-models/src/models.rs b/src-tauri/yaak-models/src/models.rs
index 9b773b83..f3f96dac 100644
--- a/src-tauri/yaak-models/src/models.rs
+++ b/src-tauri/yaak-models/src/models.rs
@@ -1323,11 +1323,13 @@ pub struct HttpResponse {
     pub body_path: Option<String>,
     pub content_length: Option<i32>,
+    pub content_length_compressed: Option<i32>,
     pub elapsed: i32,
     pub elapsed_headers: i32,
     pub error: Option<String>,
     pub headers: Vec<HttpResponseHeader>,
     pub remote_addr: Option<String>,
+    pub request_headers: Vec<HttpRequestHeader>,
     pub status: i32,
     pub status_reason: Option<String>,
     pub state: HttpResponseState,
@@ -1368,11 +1370,13 @@ impl UpsertModelInfo for HttpResponse {
             (WorkspaceId, self.workspace_id.into()),
             (BodyPath, self.body_path.into()),
             (ContentLength, self.content_length.into()),
+            (ContentLengthCompressed, self.content_length_compressed.into()),
             (Elapsed, self.elapsed.into()),
             (ElapsedHeaders, self.elapsed_headers.into()),
             (Error, self.error.into()),
             (Headers, serde_json::to_string(&self.headers)?.into()),
             (RemoteAddr, self.remote_addr.into()),
+            (RequestHeaders, serde_json::to_string(&self.request_headers)?.into()),
             (State, serde_json::to_value(self.state)?.as_str().into()),
             (Status, self.status.into()),
             (StatusReason, self.status_reason.into()),
@@ -1386,11 +1390,13 @@ impl UpsertModelInfo for HttpResponse {
             HttpResponseIden::UpdatedAt,
             HttpResponseIden::BodyPath,
             HttpResponseIden::ContentLength,
+            HttpResponseIden::ContentLengthCompressed,
             HttpResponseIden::Elapsed,
             HttpResponseIden::ElapsedHeaders,
             HttpResponseIden::Error,
             HttpResponseIden::Headers,
             HttpResponseIden::RemoteAddr,
+            HttpResponseIden::RequestHeaders,
             HttpResponseIden::State,
             HttpResponseIden::Status,
             HttpResponseIden::StatusReason,
@@ -1415,6 +1421,7 @@ impl UpsertModelInfo for HttpResponse {
             error: r.get("error")?,
             url: r.get("url")?,
             content_length: r.get("content_length")?,
+            content_length_compressed: r.get("content_length_compressed").unwrap_or_default(),
             version: r.get("version")?,
             elapsed: r.get("elapsed")?,
             elapsed_headers: r.get("elapsed_headers")?,
@@ -1424,6 +1431,10 @@ impl UpsertModelInfo for HttpResponse {
             state: serde_json::from_str(format!(r#""{state}""#).as_str()).unwrap(),
             body_path: r.get("body_path")?,
             headers: serde_json::from_str(headers.as_str()).unwrap_or_default(),
+            request_headers: serde_json::from_str(
+                r.get::<_, String>("request_headers").unwrap_or_default().as_str(),
+            )
+            .unwrap_or_default(),
         })
     }
 }
diff --git a/src-tauri/yaak-plugins/bindings/gen_models.ts b/src-tauri/yaak-plugins/bindings/gen_models.ts
index 6b2eb5c8..ebe00460 100644
--- a/src-tauri/yaak-plugins/bindings/gen_models.ts
+++ b/src-tauri/yaak-plugins/bindings/gen_models.ts
@@ -12,7 +12,7 @@ export type HttpRequest = { model: "http_request", id: string, createdAt: string
 export type HttpRequestHeader = { enabled?: boolean, name: string, value: string, id?: string, };
 
-export type HttpResponse = { model: "http_response", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, bodyPath: string | null, contentLength: number | null, elapsed: number, elapsedHeaders: number, error: string | null, headers: Array<HttpResponseHeader>, remoteAddr: string | null, status: number, statusReason: string | null, state: HttpResponseState, url:
string, version: string | null, };
+export type HttpResponse = { model: "http_response", id: string, createdAt: string, updatedAt: string, workspaceId: string, requestId: string, bodyPath: string | null, contentLength: number | null, contentLengthCompressed: number | null, elapsed: number, elapsedHeaders: number, error: string | null, headers: Array<HttpResponseHeader>, remoteAddr: string | null, requestHeaders: Array<HttpRequestHeader>, status: number, statusReason: string | null, state: HttpResponseState, url: string, version: string | null, };
 
 export type HttpResponseHeader = { name: string, value: string, };
diff --git a/src-tauri/yaak-templates/src/renderer.rs b/src-tauri/yaak-templates/src/renderer.rs
index 73ccee55..495e1120 100644
--- a/src-tauri/yaak-templates/src/renderer.rs
+++ b/src-tauri/yaak-templates/src/renderer.rs
@@ -77,6 +77,12 @@ pub struct RenderOptions {
     pub error_behavior: RenderErrorBehavior,
 }
 
+impl RenderOptions {
+    pub fn throw() -> Self {
+        Self { error_behavior: RenderErrorBehavior::Throw }
+    }
+}
+
 impl RenderErrorBehavior {
     pub fn handle<T>(&self, r: Result<T>) -> Result<T> {
         match (self, r) {
diff --git a/src-tauri/yaak-ws/src/commands.rs b/src-tauri/yaak-ws/src/commands.rs
index a61ad5ea..be1073af 100644
--- a/src-tauri/yaak-ws/src/commands.rs
+++ b/src-tauri/yaak-ws/src/commands.rs
@@ -216,7 +216,7 @@ pub(crate) async fn connect(
         &UpdateSource::from_window(&window),
     )?;
 
-    let (mut url, url_parameters) = apply_path_placeholders(&request.url, request.url_parameters);
+    let (mut url, url_parameters) = apply_path_placeholders(&request.url, &request.url_parameters);
     if !url.starts_with("ws://") && !url.starts_with("wss://") {
         url.insert_str(0, "ws://");
     }
diff --git a/src-web/components/ExportDataDialog.tsx b/src-web/components/ExportDataDialog.tsx
index 0c7c1ec7..89942fed 100644
--- a/src-web/components/ExportDataDialog.tsx
+++ b/src-web/components/ExportDataDialog.tsx
@@ -128,7 +128,7 @@ function ExportDataDialogContent({
           ))}
-
+
-
+
         ) : (
diff --git a/src-web/components/HeadersEditor.tsx b/src-web/components/HeadersEditor.tsx
index 7e6d76a6..47f036e1 100644
--- a/src-web/components/HeadersEditor.tsx
+++ b/src-web/components/HeadersEditor.tsx
@@ -34,11 +34,11 @@ export function HeadersEditor({
   const validInheritedHeaders =
     inheritedHeaders?.filter((pair) => pair.enabled && (pair.name || pair.value)) ?? [];
 
   return (
-
diff --git a/src-tauri/yaak-ws/src/commands.rs b/src-tauri/yaak-ws/src/commands.rs
index a61ad5ea..be1073af 100644
--- a/src-tauri/yaak-ws/src/commands.rs
+++ b/src-tauri/yaak-ws/src/commands.rs
@@ -216,7 +216,7 @@ pub(crate) async fn connect(
         &UpdateSource::from_window(&window),
     )?;
 
-    let (mut url, url_parameters) = apply_path_placeholders(&request.url, request.url_parameters);
+    let (mut url, url_parameters) = apply_path_placeholders(&request.url, &request.url_parameters);
 
     if !url.starts_with("ws://") && !url.starts_with("wss://") {
         url.insert_str(0, "ws://");
     }
diff --git a/src-web/components/ExportDataDialog.tsx b/src-web/components/ExportDataDialog.tsx
index 0c7c1ec7..89942fed 100644
--- a/src-web/components/ExportDataDialog.tsx
+++ b/src-web/components/ExportDataDialog.tsx
@@ -128,7 +128,7 @@ function ExportDataDialogContent({
         ))}
-
+
-
+
       ) : (
diff --git a/src-web/components/HeadersEditor.tsx b/src-web/components/HeadersEditor.tsx
index 7e6d76a6..47f036e1 100644
--- a/src-web/components/HeadersEditor.tsx
+++ b/src-web/components/HeadersEditor.tsx
@@ -34,11 +34,11 @@ export function HeadersEditor({
   const validInheritedHeaders =
     inheritedHeaders?.filter((pair) => pair.enabled && (pair.name || pair.value)) ?? [];
 
   return (
-
+
     {validInheritedHeaders.length > 0 ? (
         Inherited
diff --git a/src-web/components/HttpResponsePane.tsx b/src-web/components/HttpResponsePane.tsx
index e29bda0e..86c68ee4 100644
--- a/src-web/components/HttpResponsePane.tsx
+++ b/src-web/components/HttpResponsePane.tsx
@@ -76,7 +76,8 @@ export function HttpResponsePane({ style, className, activeRequestId }: Props) {
           label: 'Headers',
           rightSlot: (
             <CountBadge
-              count={activeResponse?.headers.filter((h) => h.name && h.value).length ?? 0}
+              count2={activeResponse?.headers.length ?? 0}
+              count={activeResponse?.requestHeaders.length ?? 0}
             />
           ),
         },
@@ -85,7 +86,13 @@ export function HttpResponsePane({ style, className, activeRequestId }: Props) {
           label: 'Info',
         },
       ],
-    [activeResponse?.headers, mimeType, setViewMode, viewMode],
+    [
+      activeResponse?.headers,
+      mimeType,
+      setViewMode,
+      viewMode,
+      activeResponse?.requestHeaders.length,
+    ],
   );
   const activeTab = activeTabs?.[activeRequestId];
   const setActiveTab = useCallback(
@@ -133,7 +140,10 @@ export function HttpResponsePane({ style, className, activeRequestId }: Props) {
-
+
-          {activeResponse?.error ? (
-
-              {activeResponse.error}
-
-          ) : (
-
-
-
-
-
-              {activeResponse.state === 'initialized' ? (
-
-
-
-
-                      Sending Request
-
-
-
-
-              ) : activeResponse.state === 'closed' &&
-                activeResponse.contentLength === 0 ? (
-                Empty
-              ) : mimeType?.match(/^text\/event-stream/i) && viewMode === 'pretty' ? (
-
-              ) : mimeType?.match(/^image\/svg/) ? (
-
-              ) : mimeType?.match(/^image/i) ? (
-
-              ) : mimeType?.match(/^audio/i) ? (
-
-              ) : mimeType?.match(/^video/i) ? (
-
-              ) : mimeType?.match(/pdf/i) ? (
-
-              ) : mimeType?.match(/csv|tab-separated/i) ? (
-
-              ) : (
-
-              )}
-
-
-
-
-
-
-
-
-
-
-
-          )}
+
+          {activeResponse?.error && (
+
+              {activeResponse.error}
+
+          )}
+          {/* Show tabs if we have any data (headers, body, etc.) even if there's an error */}
+          {(activeResponse?.headers.length > 0 ||
+            activeResponse?.bodyPath ||
+            !activeResponse?.error) && (
+
+
+
+
+
+              {activeResponse.state === 'initialized' ? (
+
+
+
+
+                      Sending Request
+
+
+
+
+              ) : activeResponse.state === 'closed' &&
+                activeResponse.contentLength === 0 ? (
+                Empty
+              ) : mimeType?.match(/^text\/event-stream/i) && viewMode === 'pretty' ? (
+
+              ) : mimeType?.match(/^image\/svg/) ? (
+
+              ) : mimeType?.match(/^image/i) ? (
+
+              ) : mimeType?.match(/^audio/i) ? (
+
+              ) : mimeType?.match(/^video/i) ? (
+
+              ) : mimeType?.match(/pdf/i) ? (
+
+              ) : mimeType?.match(/csv|tab-separated/i) ? (
+
+              ) : (
+
+              )}
+
+
+
+
+
+
+
+
+
+
+
+          )}
+
         )}
 
diff --git a/src-web/components/ResponseHeaders.tsx b/src-web/components/ResponseHeaders.tsx
index 04bffc07..d4ef9310 100644
--- a/src-web/components/ResponseHeaders.tsx
+++ b/src-web/components/ResponseHeaders.tsx
@@ -1,5 +1,7 @@
 import type { HttpResponse } from '@yaakapp-internal/models';
 import { useMemo } from 'react';
+import { CountBadge } from './core/CountBadge';
+import { DetailsBanner } from './core/DetailsBanner';
 import { KeyValueRow, KeyValueRows } from './core/KeyValueRow';
 
 interface Props {
@@ -7,20 +9,57 @@
 }
 
 export function ResponseHeaders({ response }: Props) {
-  const sortedHeaders = useMemo(
-    () => [...response.headers].sort((a, b) => a.name.localeCompare(b.name)),
+  const responseHeaders = useMemo(
+    () =>
+      [...response.headers].sort((a, b) =>
+        a.name.toLocaleLowerCase().localeCompare(b.name.toLocaleLowerCase()),
+      ),
     [response.headers],
   );
+  const requestHeaders = useMemo(
+    () =>
+      [...response.requestHeaders].sort((a, b) =>
+        a.name.toLocaleLowerCase().localeCompare(b.name.toLocaleLowerCase()),
+      ),
+    [response.requestHeaders],
+  );
 
   return (
-
-
-      {sortedHeaders.map((h, i) => (
-        // biome-ignore lint/suspicious/noArrayIndexKey: none
-
-          {h.value}
-
-      ))}
-
+
+
+          Response
+
+      }
+    >
+
+        {responseHeaders.map((h, i) => (
+          // biome-ignore lint/suspicious/noArrayIndexKey: none
+
+            {h.value}
+
+        ))}
+
+
+
+
+          Request
+
+      }
+    >
+
+        {requestHeaders.map((h, i) => (
+          // biome-ignore lint/suspicious/noArrayIndexKey: none
+
+            {h.value}
+
+        ))}
+
+
 
   );
 }
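The comparators above lowercase both names before localeCompare because HTTP header names are case-insensitive: HTTP/2 delivers them lowercased while HTTP/1.x servers often send Mixed-Case, and normalizing first keeps both styles interleaved in a single alphabetical order. A self-contained sketch of the same idea, with a simplified Header type:

    type Header = { name: string; value: string };

    // Case-insensitive, locale-aware ordering, like both useMemo blocks above.
    function sortHeaders(headers: Header[]): Header[] {
      return [...headers].sort((a, b) =>
        a.name.toLocaleLowerCase().localeCompare(b.name.toLocaleLowerCase()),
      );
    }

    // sortHeaders([{ name: 'X-Request-Id', value: '1' }, { name: 'date', value: '' }])
    //   => 'date' sorts before 'X-Request-Id' regardless of casing.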
diff --git a/src-web/components/Settings/SettingsCertificates.tsx b/src-web/components/Settings/SettingsCertificates.tsx
index d7f069ea..e6081e72 100644
--- a/src-web/components/Settings/SettingsCertificates.tsx
+++ b/src-web/components/Settings/SettingsCertificates.tsx
@@ -53,7 +53,7 @@ function CertificateEditor({ certificate, index, onUpdate, onRemove }: Certifica
 
   return (
diff --git a/src-web/components/core/Button.tsx b/src-web/components/core/Button.tsx
index c7fa3097..3d08249a 100644
--- a/src-web/components/core/Button.tsx
+++ b/src-web/components/core/Button.tsx
@@ -61,7 +61,6 @@ export const Button = forwardRef(function Button
       'x-theme-button',
       `x-theme-button--${variant}`,
       `x-theme-button--${variant}--${color}`,
-      'text-text',
       'border', // They all have borders to ensure the same width
       'max-w-full min-w-0', // Help with truncation
       'hocus:opacity-100', // Force opacity for certain hover effects
@@ -81,7 +80,7 @@
       variant === 'solid' && color === 'custom' && 'focus-visible:outline-2 outline-border-focus',
       variant === 'solid' &&
         color !== 'custom' &&
-        'enabled:hocus:text-text enabled:hocus:bg-surface-highlight outline-border-subtle',
+        'text-text enabled:hocus:text-text enabled:hocus:bg-surface-highlight outline-border-subtle',
       variant === 'solid' && color !== 'custom' && color !== 'default' && 'bg-surface', // Borders
diff --git a/src-web/components/core/CountBadge.tsx b/src-web/components/core/CountBadge.tsx
index a719c275..86c3515d 100644
--- a/src-web/components/core/CountBadge.tsx
+++ b/src-web/components/core/CountBadge.tsx
@@ -3,11 +3,12 @@ import classNames from 'classnames';
 
 interface Props {
   count: number | true;
+  count2?: number | true;
   className?: string;
   color?: Color;
 }
 
-export function CountBadge({ count, className, color }: Props) {
+export function CountBadge({ count, count2, className, color }: Props) {
   if (count === 0) return null;
 
   return (
+        /
+        {count2 === true ? (
+
+        ) : (
+          count2
+        )}
+
+    )}
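With the optional count2, CountBadge can show a pair (the Headers tab above passes request and response header counts) rather than a single total. A rough distillation of the rendered text; the real component also accepts true to show a placeholder instead of a number, which this sketch ignores:

    // Hypothetical reduction of CountBadge's output to a string.
    function badgeText(count: number, count2?: number): string {
      return count2 === undefined ? String(count) : `${count} / ${count2}`;
    }

    badgeText(8); // '8'
    badgeText(8, 12); // '8 / 12'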
   );
 }
diff --git a/src-web/components/core/DetailsBanner.tsx b/src-web/components/core/DetailsBanner.tsx
index 32368c32..df0f3c4c 100644
--- a/src-web/components/core/DetailsBanner.tsx
+++ b/src-web/components/core/DetailsBanner.tsx
@@ -1,19 +1,48 @@
 import classNames from 'classnames';
+import { atom, useAtom } from 'jotai';
 import type { HTMLAttributes, ReactNode } from 'react';
+import { useMemo } from 'react';
+import { atomWithKVStorage } from '../../lib/atoms/atomWithKVStorage';
 import type { BannerProps } from './Banner';
 import { Banner } from './Banner';
 
 interface Props extends HTMLAttributes {
   summary: ReactNode;
   color?: BannerProps['color'];
-  open?: boolean;
+  defaultOpen?: boolean;
+  storageKey?: string;
 }
 
-export function DetailsBanner({ className, color, summary, children, ...extraProps }: Props) {
+export function DetailsBanner({
+  className,
+  color,
+  summary,
+  children,
+  defaultOpen,
+  storageKey,
+  ...extraProps
+}: Props) {
+  // biome-ignore lint/correctness/useExhaustiveDependencies: We only want to recompute the atom when storageKey changes
+  const openAtom = useMemo(
+    () =>
+      storageKey
+        ? atomWithKVStorage(['details_banner', storageKey], defaultOpen ?? false)
+        : atom(defaultOpen ?? false),
+    [storageKey],
+  );
+
+  const [isOpen, setIsOpen] = useAtom(openAtom);
+
+  const handleToggle = (e: React.SyntheticEvent) => {
+    if (storageKey) {
+      setIsOpen(e.currentTarget.open);
+    }
+  };
+
   return (
-
+
+
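DetailsBanner now gets its open state from one of two atoms: a persisted one keyed under ['details_banner', storageKey] through the app's atomWithKVStorage helper, or a throwaway atom when no storageKey is given. A sketch of the same pattern that swaps the KV store for an in-memory Map, an assumption made only to keep the example self-contained:

    import type { PrimitiveAtom } from 'jotai';
    import { atom, useAtom } from 'jotai';
    import { useMemo } from 'react';

    // Share one atom per storageKey so banners with the same key stay in sync.
    const atomCache = new Map<string, PrimitiveAtom<boolean>>();

    function useDetailsOpen(storageKey: string | undefined, defaultOpen = false) {
      const openAtom = useMemo(() => {
        if (!storageKey) return atom(defaultOpen); // ephemeral; resets on remount
        let cached = atomCache.get(storageKey);
        if (!cached) {
          cached = atom(defaultOpen);
          atomCache.set(storageKey, cached);
        }
        return cached;
        // Keyed on storageKey alone, mirroring the patch (hence its biome-ignore):
        // a later change to defaultOpen must not clobber state the user toggled.
      }, [storageKey]);
      return useAtom(openAtom);
    }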
diff --git a/src-web/components/core/SizeTag.tsx b/src-web/components/core/SizeTag.tsx
+      {formatSize(contentLength)}
 
   );
diff --git a/src-web/lib/data/encodings.ts b/src-web/lib/data/encodings.ts
index 0c18fb7d..804aa496 100644
--- a/src-web/lib/data/encodings.ts
+++ b/src-web/lib/data/encodings.ts
@@ -1 +1 @@
-export const encodings = ['*', 'gzip', 'compress', 'deflate', 'br', 'identity'];
+export const encodings = ['*', 'gzip', 'compress', 'deflate', 'br', 'zstd', 'identity'];
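Adding 'zstd' keeps this autocomplete list in step with the zstd support added elsewhere in this patch. As a quick illustration (the relative import path is assumed), advertising every concrete coding except the '*' wildcard yields:

    import { encodings } from '../lib/data/encodings';

    // 'gzip, compress, deflate, br, zstd, identity'
    const acceptEncoding = encodings.filter((e) => e !== '*').join(', ');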