using System.Collections.Concurrent;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using Cleanuparr.Domain.Enums;
using Cleanuparr.Infrastructure.Helpers;
using Cleanuparr.Persistence;
using Cleanuparr.Persistence.Models.Configuration.MalwareBlocker;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Cleanuparr.Infrastructure.Features.MalwareBlocker;

/// <summary>
/// Loads per-instance blocklists (plain patterns plus <c>regex:</c>-prefixed regular expressions)
/// and the known-malware file name pattern list, and publishes them through <see cref="IMemoryCache"/>.
/// </summary>
public sealed class BlocklistProvider : IBlocklistProvider
{
    private readonly ILogger<BlocklistProvider> _logger;
    private readonly IServiceScopeFactory _scopeFactory;
    private readonly IMemoryCache _cache;

    // Hash of the settings each instance's blocklist was last loaded with; a different hash forces a reload.
    private readonly Dictionary<InstanceType, string> _configHashes = new();

    // Time of the last successful load, keyed by "{instance}_{path}" or by MalwareListKey.
    private readonly Dictionary<string, DateTime> _lastLoadTimes = new();

    private const int DefaultLoadIntervalHours = 4;
    private const int FastLoadIntervalMinutes = 5;
    private const string MalwareListUrl = "https://cleanuparr.pages.dev/static/known_malware_file_name_patterns";
    private const string MalwareListKey = "MALWARE_PATTERNS";

    public BlocklistProvider(
        ILogger<BlocklistProvider> logger,
        IServiceScopeFactory scopeFactory,
        IMemoryCache cache
    )
    {
        _logger = logger;
        _scopeFactory = scopeFactory;
        _cache = cache;
    }

    /// <summary>
    /// Reloads every enabled instance blocklist whose settings changed or whose reload interval elapsed,
    /// then refreshes the known-malware pattern list. Returns early when the Malware Blocker is disabled.
    /// Any failure is logged and rethrown.
    /// </summary>
    public async Task LoadBlocklistsAsync()
    {
        try
        {
            await using var scope = _scopeFactory.CreateAsyncScope();

            // Services resolved from the scope are owned and disposed by the scope itself;
            // disposing them here as well would double-dispose a DI-managed instance.
            var dataContext = scope.ServiceProvider.GetRequiredService<DataContext>();
            var fileReader = scope.ServiceProvider.GetRequiredService<FileReader>();

            int changedCount = 0;

            var malwareBlockerConfig = await dataContext.ContentBlockerConfigs
                .AsNoTracking()
                .FirstAsync();

            if (!malwareBlockerConfig.Enabled)
            {
                _logger.LogDebug("Malware Blocker is disabled, skipping blocklist loading");
                return;
            }

            var instances = new Dictionary<InstanceType, BlocklistSettings>
            {
                { InstanceType.Sonarr, malwareBlockerConfig.Sonarr },
                { InstanceType.Radarr, malwareBlockerConfig.Radarr },
                { InstanceType.Lidarr, malwareBlockerConfig.Lidarr },
                { InstanceType.Readarr, malwareBlockerConfig.Readarr },
                { InstanceType.Whisparr, malwareBlockerConfig.Whisparr }
            };

            foreach (var (instanceType, settings) in instances)
            {
                if (await EnsureInstanceLoadedAsync(settings, instanceType, fileReader))
                {
                    changedCount++;
                }
            }

            // Always check and update malware patterns
            await LoadMalwarePatternsAsync(fileReader);

            if (changedCount > 0)
            {
                _logger.LogInformation("Successfully loaded {count} blocklists", changedCount);
            }
            else
            {
                _logger.LogTrace("All blocklists are already up to date");
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to load blocklists");
            throw;
        }
    }

    /// <summary>
    /// Returns the cached blocklist type for <paramref name="instanceType"/>,
    /// or <see cref="BlocklistType.Blacklist"/> when nothing has been cached.
    /// </summary>
    public BlocklistType GetBlocklistType(InstanceType instanceType)
    {
        _cache.TryGetValue(CacheKeys.BlocklistType(instanceType), out BlocklistType? blocklistType);

        return blocklistType ?? BlocklistType.Blacklist;
    }

    /// <summary>Returns the cached plain patterns for <paramref name="instanceType"/>, or an empty bag.</summary>
    public ConcurrentBag<string> GetPatterns(InstanceType instanceType)
    {
        _cache.TryGetValue(CacheKeys.BlocklistPatterns(instanceType), out ConcurrentBag<string>? patterns);

        return patterns ?? [];
    }

    /// <summary>Returns the cached compiled regexes for <paramref name="instanceType"/>, or an empty bag.</summary>
    public ConcurrentBag<Regex> GetRegexes(InstanceType instanceType)
    {
        _cache.TryGetValue(CacheKeys.BlocklistRegexes(instanceType), out ConcurrentBag<Regex>? regexes);

        return regexes ?? [];
    }

    /// <summary>Returns the cached known-malware file name patterns, or an empty bag.</summary>
    public ConcurrentBag<string> GetMalwarePatterns()
    {
        _cache.TryGetValue(CacheKeys.KnownMalwarePatterns(), out ConcurrentBag<string>? patterns);

        return patterns ?? [];
    }

    /// <summary>
    /// Reloads the blocklist for one instance when it is enabled, has a path, and either its reload
    /// interval has elapsed or its settings hash differs from the last load.
    /// </summary>
    /// <returns><c>true</c> if the blocklist was (re)loaded; otherwise <c>false</c>.</returns>
    private async Task<bool> EnsureInstanceLoadedAsync(BlocklistSettings settings, InstanceType instanceType, FileReader fileReader)
    {
        if (!settings.Enabled || string.IsNullOrEmpty(settings.BlocklistPath))
        {
            return false;
        }

        string hash = GenerateSettingsHash(settings);
        var interval = GetLoadInterval(settings.BlocklistPath);
        var identifier = $"{instanceType}_{settings.BlocklistPath}";

        if (ShouldReloadBlocklist(identifier, interval)
            || !_configHashes.TryGetValue(instanceType, out string? oldHash)
            || hash != oldHash)
        {
            _logger.LogDebug("Loading {instance} blocklist", instanceType);

            await LoadPatternsAndRegexesAsync(settings, instanceType, fileReader);

            // Only recorded after a successful load, so a failed load is retried next time.
            _configHashes[instanceType] = hash;
            _lastLoadTimes[identifier] = DateTime.UtcNow;

            return true;
        }

        return false;
    }

    /// <summary>
    /// Chooses how often a blocklist source is re-read: remote URLs on cleanuparr.pages.dev and all
    /// local files use the fast interval; other remote URLs use the default (hours) interval.
    /// </summary>
    private TimeSpan GetLoadInterval(string? path)
    {
        // Uri.TryCreate also accepts absolute local paths ("/config/list.txt" on Unix, "C:\..." on Windows)
        // as file:// URIs, so exclude them explicitly; otherwise local files would get the slow interval.
        if (!string.IsNullOrEmpty(path)
            && Uri.TryCreate(path, UriKind.Absolute, out var uri)
            && !uri.IsFile)
        {
            if (uri.Host.Equals("cleanuparr.pages.dev", StringComparison.OrdinalIgnoreCase))
            {
                return TimeSpan.FromMinutes(FastLoadIntervalMinutes);
            }

            return TimeSpan.FromHours(DefaultLoadIntervalHours);
        }

        // Local files (and anything that is not a remote URL) use the fast load interval.
        return TimeSpan.FromMinutes(FastLoadIntervalMinutes);
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="identifier"/> has never been loaded
    /// or its last load is at least <paramref name="interval"/> old.
    /// </summary>
    private bool ShouldReloadBlocklist(string identifier, TimeSpan interval)
    {
        if (!_lastLoadTimes.TryGetValue(identifier, out DateTime lastLoad))
        {
            return true;
        }

        return DateTime.UtcNow - lastLoad >= interval;
    }

    /// <summary>
    /// Refreshes the known-malware pattern list from <see cref="MalwareListUrl"/> when its fast reload
    /// interval has elapsed. Failures are logged as warnings and not rethrown; because the load time is
    /// only recorded on success, the next call retries.
    /// </summary>
    private async Task LoadMalwarePatternsAsync(FileReader fileReader)
    {
        var malwareInterval = TimeSpan.FromMinutes(FastLoadIntervalMinutes);

        if (!ShouldReloadBlocklist(MalwareListKey, malwareInterval))
        {
            return;
        }

        try
        {
            _logger.LogDebug("Loading malware patterns");

            string[] filePatterns = await fileReader.ReadContentAsync(MalwareListUrl);

            long startTime = Stopwatch.GetTimestamp();

            // Plain strings need no per-item processing, so copy them directly instead of using Parallel.ForEach.
            ConcurrentBag<string> patterns = new(filePatterns);

            TimeSpan elapsed = Stopwatch.GetElapsedTime(startTime);

            _cache.Set(CacheKeys.KnownMalwarePatterns(), patterns);
            _lastLoadTimes[MalwareListKey] = DateTime.UtcNow;

            _logger.LogDebug("loaded {count} known malware patterns", patterns.Count);
            _logger.LogDebug("malware patterns loaded in {elapsed} ms", elapsed.TotalMilliseconds);
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to load malware patterns from {url}", MalwareListUrl);
        }
    }

    /// <summary>
    /// Reads the blocklist at <see cref="BlocklistSettings.BlocklistPath"/> and caches its blocklist type,
    /// its plain patterns, and its compiled regexes (entries prefixed with <c>regex:</c>).
    /// Invalid regexes are logged and skipped.
    /// </summary>
    private async Task LoadPatternsAndRegexesAsync(BlocklistSettings blocklistSettings, InstanceType instanceType, FileReader fileReader)
    {
        if (string.IsNullOrEmpty(blocklistSettings.BlocklistPath))
        {
            return;
        }

        string[] filePatterns = await fileReader.ReadContentAsync(blocklistSettings.BlocklistPath);

        long startTime = Stopwatch.GetTimestamp();
        ParallelOptions options = new() { MaxDegreeOfParallelism = 5 };
        const string regexId = "regex:";
        ConcurrentBag<string> patterns = [];
        ConcurrentBag<Regex> regexes = [];

        // Parallelized mainly because RegexOptions.Compiled makes each regex construction expensive.
        Parallel.ForEach(filePatterns, options, pattern =>
        {
            // Ordinal: the prefix is a fixed marker, not linguistic text.
            if (!pattern.StartsWith(regexId, StringComparison.Ordinal))
            {
                patterns.Add(pattern);
                return;
            }

            pattern = pattern[regexId.Length..];

            try
            {
                Regex regex = new(pattern, RegexOptions.Compiled);
                regexes.Add(regex);
            }
            catch (ArgumentException)
            {
                _logger.LogWarning("invalid regex | {pattern}", pattern);
            }
        });

        TimeSpan elapsed = Stopwatch.GetElapsedTime(startTime);

        _cache.Set(CacheKeys.BlocklistType(instanceType), blocklistSettings.BlocklistType);
        _cache.Set(CacheKeys.BlocklistPatterns(instanceType), patterns);
        _cache.Set(CacheKeys.BlocklistRegexes(instanceType), regexes);

        _logger.LogDebug("loaded {count} patterns", patterns.Count);
        _logger.LogDebug("loaded {count} regexes", regexes.Count);
        _logger.LogDebug("blocklist loaded in {elapsed} ms | {path}", elapsed.TotalMilliseconds, blocklistSettings.BlocklistPath);
    }

    /// <summary>
    /// Produces a lowercase hex SHA-256 of the settings that affect loading
    /// (enabled flag, path, blocklist type), used to detect configuration changes.
    /// </summary>
    private string GenerateSettingsHash(BlocklistSettings blocklistSettings)
    {
        // Create a string that represents the relevant blocklist configuration
        var configStr = $"{blocklistSettings.Enabled}|{blocklistSettings.BlocklistPath ?? string.Empty}|{blocklistSettings.BlocklistType}";

        byte[] hashBytes = SHA256.HashData(Encoding.UTF8.GetBytes(configStr));

        return Convert.ToHexString(hashBytes).ToLowerInvariant();
    }
}