fix: resolve async/await issues in file transfer path validation

- Changed get_all_allowed_paths() to async to properly await library manager operations
- Changed is_path_allowed() to async and fixed RwLock held across await points
- Changed validate_path_access() to async for consistency
- Replaced blocking_read() and block_on() with proper async/await
- Clone RwLock data before await to ensure future is Send
- Updated all call sites to use .await
- Improved error logging with structured tracing

This fixes compilation errors where futures were not Send due to
RwLockReadGuard being held across await points, and prevents potential
deadlocks from blocking operations in async contexts.

Fixes the file_transfer_test which now passes successfully.
This commit is contained in:
Jamie Pine
2025-12-30 10:50:30 -08:00
parent 59686730b3
commit a292cc31cc
2 changed files with 99 additions and 57 deletions

View File

@@ -365,30 +365,29 @@ impl FileTransferProtocolHandler {
}
/// Get all allowed paths by combining static allowed_paths with dynamic locations.
/// This queries all libraries for their registered locations.
fn get_all_allowed_paths(&self) -> Vec<PathBuf> {
/// This queries all libraries for their registered locations asynchronously.
async fn get_all_allowed_paths(&self) -> Vec<PathBuf> {
let mut paths = Vec::new();
// Add statically configured allowed paths
{
// Add statically configured allowed paths (clone to avoid holding lock across await)
let static_paths = {
let allowed = self.allowed_paths.read().unwrap();
paths.extend(allowed.clone());
}
allowed.clone()
};
paths.extend(static_paths);
// Add dynamic location paths from all libraries via CoreContext
if let Some(ctx) = &self.core_context {
let library_manager_guard = ctx.library_manager.blocking_read();
let library_manager_guard = ctx.library_manager.read().await;
if let Some(library_manager) = library_manager_guard.as_ref() {
// Get all active libraries using tokio's block_on (safe in async context)
let library_list: Vec<std::sync::Arc<crate::library::Library>> =
tokio::runtime::Handle::current().block_on(library_manager.list());
// Get all active libraries
let library_list = library_manager.list().await;
for library in library_list {
// Get locations for this library using LocationManager
let location_manager =
crate::location::LocationManager::new((*ctx.events).clone());
if let Ok(locations) = tokio::runtime::Handle::current()
.block_on(location_manager.list_locations(&library))
{
if let Ok(locations) = location_manager.list_locations(&library).await {
for loc in locations {
paths.push(loc.path.clone());
}
@@ -402,7 +401,7 @@ impl FileTransferProtocolHandler {
/// Check if a path is within one of the allowed paths.
/// Uses canonicalization to prevent traversal attacks.
fn is_path_allowed(&self, path: &std::path::Path) -> bool {
async fn is_path_allowed(&self, path: &std::path::Path) -> bool {
// Canonicalize the target path to resolve symlinks and `..`
let canonical_path = match path.canonicalize() {
Ok(p) => p,
@@ -411,16 +410,27 @@ impl FileTransferProtocolHandler {
if let Some(parent) = path.parent() {
match parent.canonicalize() {
Ok(p) => p,
Err(_) => return false, // Parent doesn't exist
Err(e) => {
tracing::warn!(
path = ?path,
error = %e,
"File transfer path validation failed: parent directory doesn't exist"
);
return false; // Parent doesn't exist
}
}
} else {
tracing::warn!(
path = ?path,
"File transfer path validation failed: no parent directory"
);
return false; // No parent (root path)
}
}
};
// Get all allowed paths (static + dynamic from locations)
let allowed_paths = self.get_all_allowed_paths();
let allowed_paths = self.get_all_allowed_paths().await;
// If no allowed paths are configured and no context, deny all (fail-safe)
if allowed_paths.is_empty() {
@@ -442,9 +452,9 @@ impl FileTransferProtocolHandler {
}
tracing::warn!(
"Path {:?} is not within any allowed location. Allowed: {:?}",
path,
allowed_paths.iter().take(5).collect::<Vec<_>>() // Log first 5 for brevity
path = ?path,
allowed_paths = ?allowed_paths.iter().take(5).collect::<Vec<_>>(),
"File transfer denied: path is not within any allowed location"
);
false
}
@@ -668,7 +678,7 @@ impl FileTransferProtocolHandler {
{
// SECURITY: Validate destination path is within allowed locations
let dest_path = std::path::Path::new(&destination_path);
if !self.is_path_allowed(dest_path) {
if !self.is_path_allowed(dest_path).await {
tracing::warn!(
path = %destination_path,
from_device = %from_device,
@@ -984,7 +994,7 @@ impl FileTransferProtocolHandler {
// Validate destination path is within allowed locations
// This prevents arbitrary file write attacks from malicious peers.
let dest_path_buf = PathBuf::from(&destination_path);
if !self.is_path_allowed(&dest_path_buf) {
if !self.is_path_allowed(&dest_path_buf).await {
self.logger
.warn(&format!(
"Transfer {} rejected: destination path {:?} is not within allowed locations",
@@ -1151,7 +1161,7 @@ impl FileTransferProtocolHandler {
/// Validate that a path is safe to access for PULL requests.
/// Prevents directory traversal attacks and enforces access boundaries.
/// SECURITY: Only allows access to files within registered locations.
fn validate_path_access(&self, path: &std::path::Path, _requested_by: Uuid) -> bool {
async fn validate_path_access(&self, path: &std::path::Path, _requested_by: Uuid) -> bool {
// Normalize path to prevent directory traversal.
// canonicalize() resolves all symlinks and `..` components.
let normalized = match path.canonicalize() {
@@ -1166,7 +1176,7 @@ impl FileTransferProtocolHandler {
// Validate path is within allowed locations
// This prevents arbitrary file read attacks from malicious peers.
if !self.is_path_allowed(&normalized) {
if !self.is_path_allowed(&normalized).await {
tracing::warn!(
"Path access denied: {:?} is not within allowed locations",
path
@@ -1197,7 +1207,7 @@ impl FileTransferProtocolHandler {
.await;
// Security validation
if !self.validate_path_access(&source_path, requested_by) {
if !self.validate_path_access(&source_path, requested_by).await {
self.logger
.warn(&format!(
"PULL request {} rejected: access denied for path {}",
@@ -1880,8 +1890,8 @@ mod tests {
// Path validation security tests
#[test]
fn test_is_path_allowed_rejects_paths_outside_allowed_locations() {
#[tokio::test]
async fn test_is_path_allowed_rejects_paths_outside_allowed_locations() {
let logger = Arc::new(SilentLogger);
let handler = FileTransferProtocolHandler::new_default(logger);
@@ -1893,7 +1903,7 @@ mod tests {
// Test: Path outside allowed location should be REJECTED
let outside_path = std::path::Path::new("/etc/passwd");
assert!(
!handler.is_path_allowed(outside_path),
!handler.is_path_allowed(outside_path).await,
"Paths outside allowed locations must be rejected"
);
@@ -1902,7 +1912,7 @@ mod tests {
{
let system_path = std::path::Path::new("C:\\Windows\\System32\\config\\SAM");
assert!(
!handler.is_path_allowed(system_path),
!handler.is_path_allowed(system_path).await,
"System paths must be rejected"
);
}
@@ -1911,8 +1921,8 @@ mod tests {
std::fs::remove_dir_all(&temp_dir).ok();
}
#[test]
fn test_is_path_allowed_accepts_paths_inside_allowed_locations() {
#[tokio::test]
async fn test_is_path_allowed_accepts_paths_inside_allowed_locations() {
let logger = Arc::new(SilentLogger);
let handler = FileTransferProtocolHandler::new_default(logger);
@@ -1926,7 +1936,7 @@ mod tests {
// Test: Path inside allowed location should be ACCEPTED
assert!(
handler.is_path_allowed(&inner_path),
handler.is_path_allowed(&inner_path).await,
"Paths inside allowed locations should be accepted"
);
@@ -1934,8 +1944,8 @@ mod tests {
std::fs::remove_dir_all(&temp_dir).ok();
}
#[test]
fn test_is_path_allowed_rejects_traversal_attempts() {
#[tokio::test]
async fn test_is_path_allowed_rejects_traversal_attempts() {
let logger = Arc::new(SilentLogger);
let handler = FileTransferProtocolHandler::new_default(logger);
@@ -1948,7 +1958,7 @@ mod tests {
// Note: canonicalize() will resolve this, but if it resolves outside, it's rejected
let traversal_path = temp_dir.join("..").join("..").join("etc").join("passwd");
assert!(
!handler.is_path_allowed(&traversal_path),
!handler.is_path_allowed(&traversal_path).await,
"Path traversal attempts must be rejected"
);
@@ -1956,21 +1966,21 @@ mod tests {
std::fs::remove_dir_all(&temp_dir).ok();
}
#[test]
fn test_is_path_allowed_denies_all_when_no_paths_configured() {
#[tokio::test]
async fn test_is_path_allowed_denies_all_when_no_paths_configured() {
let logger = Arc::new(SilentLogger);
let handler = FileTransferProtocolHandler::new_default(logger);
// Don't configure any allowed paths - this should deny ALL access (fail-safe)
let any_path = std::env::temp_dir().join("some_file.txt");
assert!(
!handler.is_path_allowed(&any_path),
!handler.is_path_allowed(&any_path).await,
"When no allowed paths configured, all access should be denied"
);
}
#[test]
fn test_add_allowed_path_works() {
#[tokio::test]
async fn test_add_allowed_path_works() {
let logger = Arc::new(SilentLogger);
let handler = FileTransferProtocolHandler::new_default(logger);
@@ -1980,14 +1990,14 @@ mod tests {
std::fs::write(&file_path, "content").ok();
// Initially denied
assert!(!handler.is_path_allowed(&file_path));
assert!(!handler.is_path_allowed(&file_path).await);
// Add the path
handler.add_allowed_path(temp_dir.clone());
// Now allowed
assert!(
handler.is_path_allowed(&file_path),
handler.is_path_allowed(&file_path).await,
"add_allowed_path should permit access to paths within the added directory"
);

View File

@@ -51,6 +51,25 @@ async fn alice_file_transfer_scenario() {
tokio::time::sleep(Duration::from_secs(3)).await;
println!("Alice: Networking initialized successfully");
// Create directory for received files BEFORE adding as allowed path
std::fs::create_dir_all("/tmp/received_files").unwrap();
// Add allowed path for file transfers (security requirement from PR #2944)
if let Some(networking) = core.networking() {
let protocol_registry = networking.protocol_registry();
let registry = protocol_registry.read().await;
if let Some(handler) = registry.get_handler("file_transfer") {
if let Some(ft_handler) =
handler
.as_any()
.downcast_ref::<sd_core::service::network::protocol::FileTransferProtocolHandler>(
) {
ft_handler.add_allowed_path(std::path::PathBuf::from("/tmp/received_files"));
println!("Alice: Added /tmp/received_files as allowed path");
}
}
}
// Create a library for job dispatch (required for file transfers)
println!("Alice: Creating library for file transfer jobs...");
let _library = core
@@ -97,20 +116,18 @@ async fn alice_file_transfer_scenario() {
println!("Alice: Waiting for Bob to connect...");
let mut attempts = 0;
let max_attempts = 45; // 45 seconds
let mut receiver_device_id = None;
loop {
let receiver_id = loop {
tokio::time::sleep(Duration::from_secs(1)).await;
let connected_devices = core.services.device.get_connected_devices().await.unwrap();
if !connected_devices.is_empty() {
receiver_device_id = Some(connected_devices[0]);
println!("Alice: Bob connected! Device ID: {}", connected_devices[0]);
// Wait a bit longer to ensure session keys are properly established
println!("Alice: Allowing extra time for session key establishment...");
tokio::time::sleep(Duration::from_secs(2)).await;
break;
break connected_devices[0];
}
// Also check if there are any paired devices (even if not currently connected)
@@ -127,12 +144,11 @@ async fn alice_file_transfer_scenario() {
);
}
// Use the first paired device as the receiver
receiver_device_id = Some(paired_devices[0].device_id);
println!(
"Alice: Using paired device as receiver: {}",
paired_devices[0].device_id
);
break;
break paired_devices[0].device_id;
}
}
@@ -144,9 +160,7 @@ async fn alice_file_transfer_scenario() {
if attempts % 5 == 0 {
println!("Alice: Pairing status check {} - waiting", attempts / 5);
}
}
let receiver_id = receiver_device_id.unwrap();
};
// Create test files to transfer
println!("Alice: Creating test files for transfer...");
@@ -378,6 +392,30 @@ async fn bob_file_transfer_scenario() {
tokio::time::sleep(Duration::from_secs(3)).await;
println!("Bob: Networking initialized successfully");
// Create directory for received files BEFORE pairing (security requirement from PR #2944)
let received_dir = std::path::Path::new("/tmp/received_files");
std::fs::create_dir_all(received_dir).unwrap();
println!(
"Bob: Created directory for received files: {:?}",
received_dir
);
// Add allowed path for file transfers AFTER directory creation (security requirement from PR #2944)
if let Some(networking) = core.networking() {
let protocol_registry = networking.protocol_registry();
let registry = protocol_registry.read().await;
if let Some(handler) = registry.get_handler("file_transfer") {
if let Some(ft_handler) =
handler
.as_any()
.downcast_ref::<sd_core::service::network::protocol::FileTransferProtocolHandler>(
) {
ft_handler.add_allowed_path(std::path::PathBuf::from("/tmp/received_files"));
println!("Bob: Added /tmp/received_files as allowed path");
}
}
}
// Create a library for job dispatch (required for file transfers)
println!("Bob: Creating library for file transfer jobs...");
let _library = core
@@ -480,13 +518,7 @@ async fn bob_file_transfer_scenario() {
// Wait for file transfers
println!("Bob: Waiting for file transfers...");
// Create directory for received files
let received_dir = std::path::Path::new("/tmp/received_files");
std::fs::create_dir_all(received_dir).unwrap();
println!(
"Bob: Created directory for received files: {:?}",
received_dir
);
// Directory and allowed path already configured before pairing
// Wait for expected files to arrive
let expected_files = loop {