using System.Net;
using System.Security.Claims;
using System.Text.Encodings.Web;
using Cleanuparr.Infrastructure.Extensions;
using Cleanuparr.Persistence;
using Microsoft.AspNetCore.Authentication;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options;

namespace Cleanuparr.Api.Auth;

/// <summary>
/// Constants for the trusted-network authentication scheme.
/// </summary>
public static class TrustedNetworkAuthenticationDefaults
{
    /// <summary>Scheme name used when issuing identities and tickets from this handler.</summary>
    public const string AuthenticationScheme = "TrustedNetwork";
}

/// <summary>
/// Authenticates a request without credentials when the stored auth config has
/// <c>DisableAuthForLocalAddresses</c> enabled and the client IP is either a local address or
/// inside one of the configured trusted networks. The resulting identity is the first user whose
/// <c>SetupCompleted</c> flag is set. In every other case the handler returns
/// <see cref="AuthenticateResult.NoResult"/>, leaving the decision to other schemes.
/// </summary>
public class TrustedNetworkAuthenticationHandler : AuthenticationHandler<AuthenticationSchemeOptions>
{
    public TrustedNetworkAuthenticationHandler(
        IOptionsMonitor<AuthenticationSchemeOptions> options,
        ILoggerFactory logger,
        UrlEncoder encoder)
        : base(options, logger, encoder)
    {
    }

    /// <summary>
    /// Checks the auth config and the caller's address and, when trusted, issues a ticket carrying the
    /// set-up user's id and name plus an <c>auth_method=trusted_network</c> claim.
    /// </summary>
    protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
    {
        var cancellationToken = Context.RequestAborted;

        // The config is re-read from the database on every request, so the current stored value always applies.
        await using var dataContext = DataContext.CreateStaticInstance();
        var config = await dataContext.GeneralConfigs
            .AsNoTracking()
            .FirstOrDefaultAsync(cancellationToken);

        if (config is null || !config.Auth.DisableAuthForLocalAddresses)
        {
            return AuthenticateResult.NoResult();
        }

        var clientIp = GetClientIp(config.Auth.TrustForwardedHeaders);
        if (clientIp is null || !IsTrustedAddress(clientIp, config.Auth.TrustedNetworks))
        {
            return AuthenticateResult.NoResult();
        }

        // Load the user the trusted request will act as.
        await using var usersContext = UsersContext.CreateStaticInstance();
        var user = await usersContext.Users
            .AsNoTracking()
            .FirstOrDefaultAsync(u => u.SetupCompleted, cancellationToken);

        if (user is null)
        {
            return AuthenticateResult.NoResult();
        }

        var claims = new[]
        {
            new Claim(ClaimTypes.NameIdentifier, user.Id.ToString()),
            new Claim(ClaimTypes.Name, user.Username),
            new Claim("auth_method", "trusted_network"),
        };

        var identity = new ClaimsIdentity(claims, TrustedNetworkAuthenticationDefaults.AuthenticationScheme);
        var principal = new ClaimsPrincipal(identity);
        var ticket = new AuthenticationTicket(principal, TrustedNetworkAuthenticationDefaults.AuthenticationScheme);

        return AuthenticateResult.Success(ticket);
    }

    private IPAddress? GetClientIp(bool trustForwardedHeaders) => ResolveClientIp(Context, trustForwardedHeaders);

    /// <summary>
    /// Determines the effective client IP of a request.
    /// </summary>
    /// <param name="httpContext">The current request context.</param>
    /// <param name="trustForwardedHeaders">
    /// Whether <c>X-Forwarded-For</c> / <c>X-Real-IP</c> may be consulted. They are only honoured when the
    /// direct peer is itself a local address (presumably a reverse proxy).
    /// </param>
    /// <returns>
    /// The resolved client address, or <c>null</c> when the connection has no remote address.
    /// </returns>
    /// <remarks>
    /// <c>X-Forwarded-For</c> is evaluated right-to-left. Each proxy appends the address it saw, so the
    /// left-most entries are whatever the client sent and cannot be trusted on their own. Starting from the
    /// entry written by the (local) direct peer, local hops are treated as further proxies and skipped. The
    /// first non-local hop is taken as the client. This only protects against spoofing if the proxy appends
    /// to, or overwrites, the incoming header; a proxy that forwards a client-supplied header untouched
    /// cannot be defended against here.
    /// </remarks>
    public static IPAddress? ResolveClientIp(HttpContext httpContext, bool trustForwardedHeaders)
    {
        var remoteIp = httpContext.Connection.RemoteIpAddress;
        if (remoteIp is null)
        {
            return null;
        }

        // Dual-stack sockets report IPv4 peers as ::ffff:a.b.c.d; normalise before the local check, the
        // same way IsTrustedAddress does, so a local proxy is recognised regardless of socket family.
        if (!trustForwardedHeaders || !NormalizeAddress(remoteIp).IsLocalAddress())
        {
            return remoteIp;
        }

        // A proxy may add its own X-Forwarded-For line instead of extending the client's, so every header
        // value is considered, in order.
        var forwardedHops = httpContext.Request.Headers["X-Forwarded-For"]
            .SelectMany(value => (value ?? string.Empty)
                .Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries))
            .ToArray();

        var forwardedClient = SelectForwardedClient(forwardedHops);
        if (forwardedClient is not null)
        {
            return forwardedClient;
        }

        var realIp = httpContext.Request.Headers["X-Real-IP"].FirstOrDefault();
        if (!string.IsNullOrEmpty(realIp) && IPAddress.TryParse(realIp, out var realParsedIp))
        {
            return realParsedIp;
        }

        return remoteIp;
    }

    /// <summary>
    /// Picks the client address out of X-Forwarded-For hops (index 0 = oldest, last = nearest proxy).
    /// </summary>
    /// <returns>
    /// The first non-local hop counting from the right, the left-most parseable hop if every hop visited
    /// is local, or <c>null</c> if there are no hops or the right-most hop is unparseable.
    /// </returns>
    private static IPAddress? SelectForwardedClient(IReadOnlyList<string> hops)
    {
        IPAddress? candidate = null;

        for (var i = hops.Count - 1; i >= 0; i--)
        {
            if (!IPAddress.TryParse(hops[i], out var hop))
            {
                // Nothing left of an unparseable entry can be attributed to a trusted hop.
                break;
            }

            candidate = hop;

            // Every entry visited so far was written by a local hop, so this one is vouched for. A non-local
            // address cannot be a trusted proxy, so it is the client and anything further left is ignored.
            if (!NormalizeAddress(hop).IsLocalAddress())
            {
                break;
            }
        }

        return candidate;
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="clientIp"/> is a local address or matches any entry of
    /// <paramref name="trustedNetworks"/> (CIDR notation or a single IP).
    /// </summary>
    public static bool IsTrustedAddress(IPAddress clientIp, List<string> trustedNetworks)
    {
        var normalized = NormalizeAddress(clientIp);

        // Built-in local ranges are always trusted.
        if (normalized.IsLocalAddress())
        {
            return true;
        }

        // A null or empty list trusts nothing extra (and must not throw inside the auth pipeline).
        if (trustedNetworks is not { Count: > 0 })
        {
            return false;
        }

        return trustedNetworks.Any(network => MatchesCidr(normalized, network));
    }

    /// <summary>
    /// Tests whether <paramref name="address"/> falls within <paramref name="cidr"/>, which may be a
    /// network in <c>address/prefix</c> form or a single IP address.
    /// </summary>
    /// <returns>
    /// <c>true</c> on a match. Returns <c>false</c> for malformed entries: unparseable address or prefix,
    /// more than one '/', or a prefix outside <c>0..32</c> (IPv4) / <c>0..128</c> (IPv6).
    /// </returns>
    /// <remarks>
    /// IPv4-mapped IPv6 forms are treated as their IPv4 equivalents on both sides. For a mapped network
    /// the prefix is rebased by the 96-bit <c>::ffff:0:0</c> header, so <c>::ffff:192.168.0.0/112</c>
    /// behaves like <c>192.168.0.0/16</c>.
    /// </remarks>
    public static bool MatchesCidr(IPAddress address, string cidr)
    {
        if (string.IsNullOrWhiteSpace(cidr))
        {
            return false;
        }

        address = NormalizeAddress(address);

        var parts = cidr.Split('/', StringSplitOptions.TrimEntries);
        if (parts.Length == 1)
        {
            // Plain IP match
            return IPAddress.TryParse(parts[0], out var singleIp)
                && address.Equals(NormalizeAddress(singleIp));
        }

        if (parts.Length != 2
            || !IPAddress.TryParse(parts[0], out var networkAddress)
            || !int.TryParse(parts[1], out var prefixLength))
        {
            return false;
        }

        // Reject prefixes outside the written address's width; a negative prefix would otherwise skip
        // every comparison below and match any address.
        var maxPrefixLength = networkAddress.GetAddressBytes().Length * 8;
        if (prefixLength < 0 || prefixLength > maxPrefixLength)
        {
            return false;
        }

        if (networkAddress.IsIPv4MappedToIPv6)
        {
            networkAddress = networkAddress.MapToIPv4();

            // The first 96 bits of a mapped address are the fixed ::ffff: header, which every mapped IPv4
            // address shares, so only bits beyond 96 constrain the IPv4 part.
            prefixLength = Math.Max(0, prefixLength - 96);
        }

        // Must be same address family
        if (address.AddressFamily != networkAddress.AddressFamily)
        {
            return false;
        }

        var addressBytes = address.GetAddressBytes();
        var networkBytes = networkAddress.GetAddressBytes();

        // Same family guarantees equal byte lengths, and prefixLength is within that width.
        var fullBytes = prefixLength / 8;
        var remainingBits = prefixLength % 8;

        for (var i = 0; i < fullBytes; i++)
        {
            if (addressBytes[i] != networkBytes[i])
            {
                return false;
            }
        }

        if (remainingBits == 0)
        {
            return true;
        }

        var mask = (byte)(0xFF << (8 - remainingBits));
        return (addressBytes[fullBytes] & mask) == (networkBytes[fullBytes] & mask);
    }

    /// <summary>Converts an IPv4-mapped IPv6 address to plain IPv4; other addresses are returned unchanged.</summary>
    private static IPAddress NormalizeAddress(IPAddress address) =>
        address.IsIPv4MappedToIPv6 ? address.MapToIPv4() : address;
}