// Files · 269 lines · 9.6 KiB · C#

using System.Collections.Concurrent;
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using Cleanuparr.Domain.Enums;
using Cleanuparr.Infrastructure.Helpers;
using Cleanuparr.Persistence;
using Cleanuparr.Persistence.Models.Configuration.MalwareBlocker;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace Cleanuparr.Infrastructure.Features.MalwareBlocker;
/// <summary>
/// Loads the per-instance blocklists and the shared known-malware pattern list, parses them and
/// publishes the results through an <see cref="IMemoryCache"/>.
/// </summary>
/// <remarks>
/// Reload bookkeeping (<c>_configHashes</c>, <c>_lastLoadTimes</c>) is kept in plain dictionaries
/// without any synchronization.
/// NOTE(review): presumably <see cref="LoadBlocklistsAsync"/> is never executed concurrently — confirm.
/// </remarks>
public sealed class BlocklistProvider : IBlocklistProvider
{
    /// <summary>Reload interval, in hours, for blocklists served from any other remote host.</summary>
    private const int DefaultLoadIntervalHours = 4;

    /// <summary>Reload interval, in minutes, for local files, the official host and the malware list.</summary>
    private const int FastLoadIntervalMinutes = 5;

    /// <summary>Host whose blocklists are re-read on the fast interval.</summary>
    private const string OfficialBlocklistHost = "cleanuparr.pages.dev";

    private const string MalwareListUrl = "https://cleanuparr.pages.dev/static/known_malware_file_name_patterns";

    /// <summary>Key under which the malware list's last load time is tracked.</summary>
    private const string MalwareListKey = "MALWARE_PATTERNS";

    /// <summary>Prefix marking a blocklist entry as a regular expression rather than a plain pattern.</summary>
    private const string RegexPrefix = "regex:";

    private readonly ILogger<BlocklistProvider> _logger;
    private readonly IServiceScopeFactory _scopeFactory;
    private readonly IMemoryCache _cache;

    /// <summary>Hash of the settings each instance type was last loaded with.</summary>
    private readonly Dictionary<InstanceType, string> _configHashes = new();

    /// <summary>UTC time of the last successful load, keyed by blocklist identifier.</summary>
    private readonly Dictionary<string, DateTime> _lastLoadTimes = new();

    public BlocklistProvider(
        ILogger<BlocklistProvider> logger,
        IServiceScopeFactory scopeFactory,
        IMemoryCache cache
    )
    {
        _logger = logger;
        _scopeFactory = scopeFactory;
        _cache = cache;
    }

    /// <summary>
    /// Reads the malware blocker configuration and (re)loads every instance blocklist that is due,
    /// followed by the known-malware pattern list. Does nothing when the feature is disabled.
    /// </summary>
    /// <exception cref="Exception">
    /// Any failure while reading the configuration or an instance blocklist is logged and rethrown.
    /// Failures while loading the malware pattern list are logged as warnings and not propagated.
    /// </exception>
    public async Task LoadBlocklistsAsync()
    {
        try
        {
            await using var scope = _scopeFactory.CreateAsyncScope();
            await using var dataContext = scope.ServiceProvider.GetRequiredService<DataContext>();
            var fileReader = scope.ServiceProvider.GetRequiredService<FileReader>();

            var malwareBlockerConfig = await dataContext.ContentBlockerConfigs
                .AsNoTracking()
                .FirstAsync();

            if (!malwareBlockerConfig.Enabled)
            {
                _logger.LogDebug("Malware Blocker is disabled, skipping blocklist loading");
                return;
            }

            var instances = new Dictionary<InstanceType, BlocklistSettings>
            {
                { InstanceType.Sonarr, malwareBlockerConfig.Sonarr },
                { InstanceType.Radarr, malwareBlockerConfig.Radarr },
                { InstanceType.Lidarr, malwareBlockerConfig.Lidarr },
                { InstanceType.Readarr, malwareBlockerConfig.Readarr },
                { InstanceType.Whisparr, malwareBlockerConfig.Whisparr }
            };

            int changedCount = 0;

            foreach (var (instanceType, settings) in instances)
            {
                if (await EnsureInstanceLoadedAsync(settings, instanceType, fileReader))
                {
                    changedCount++;
                }
            }

            // Always check and update malware patterns
            await LoadMalwarePatternsAsync(fileReader);

            if (changedCount > 0)
            {
                _logger.LogInformation("Successfully loaded {count} blocklists", changedCount);
            }
            else
            {
                _logger.LogTrace("All blocklists are already up to date");
            }
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to load blocklists");
            throw;
        }
    }

    /// <summary>
    /// Returns the cached blocklist type for <paramref name="instanceType"/>, or
    /// <see cref="BlocklistType.Blacklist"/> when nothing is cached.
    /// </summary>
    public BlocklistType GetBlocklistType(InstanceType instanceType)
    {
        _cache.TryGetValue(CacheKeys.BlocklistType(instanceType), out BlocklistType? blocklistType);

        return blocklistType ?? BlocklistType.Blacklist;
    }

    /// <summary>
    /// Returns the cached plain patterns for <paramref name="instanceType"/>, or an empty bag when
    /// nothing is cached.
    /// </summary>
    public ConcurrentBag<string> GetPatterns(InstanceType instanceType)
    {
        _cache.TryGetValue(CacheKeys.BlocklistPatterns(instanceType), out ConcurrentBag<string>? patterns);

        return patterns ?? [];
    }

    /// <summary>
    /// Returns the cached compiled regexes for <paramref name="instanceType"/>, or an empty bag when
    /// nothing is cached.
    /// </summary>
    public ConcurrentBag<Regex> GetRegexes(InstanceType instanceType)
    {
        _cache.TryGetValue(CacheKeys.BlocklistRegexes(instanceType), out ConcurrentBag<Regex>? regexes);

        return regexes ?? [];
    }

    /// <summary>
    /// Returns the cached known-malware patterns, or an empty bag when nothing is cached.
    /// </summary>
    public ConcurrentBag<string> GetMalwarePatterns()
    {
        _cache.TryGetValue(CacheKeys.KnownMalwarePatterns(), out ConcurrentBag<string>? patterns);

        return patterns ?? [];
    }

    /// <summary>
    /// Reloads the blocklist of one instance when its reload interval has elapsed, it was never
    /// loaded, or its settings changed since the last load.
    /// </summary>
    /// <param name="settings">Blocklist settings of the instance.</param>
    /// <param name="instanceType">Instance the settings belong to.</param>
    /// <param name="fileReader">Reader used to fetch the blocklist content.</param>
    /// <returns><c>true</c> when the blocklist was (re)loaded; otherwise <c>false</c>.</returns>
    private async Task<bool> EnsureInstanceLoadedAsync(BlocklistSettings settings, InstanceType instanceType, FileReader fileReader)
    {
        if (!settings.Enabled || string.IsNullOrEmpty(settings.BlocklistPath))
        {
            return false;
        }

        string hash = GenerateSettingsHash(settings);
        TimeSpan interval = GetLoadInterval(settings.BlocklistPath);
        string identifier = $"{instanceType}_{settings.BlocklistPath}";

        bool isStale = ShouldReloadBlocklist(identifier, interval);
        bool configChanged = !_configHashes.TryGetValue(instanceType, out string? oldHash) || hash != oldHash;

        if (!isStale && !configChanged)
        {
            return false;
        }

        _logger.LogDebug("Loading {instance} blocklist", instanceType);

        await LoadPatternsAndRegexesAsync(settings, instanceType, fileReader);

        // Bookkeeping is only updated after a successful load, so a failed load is retried next time.
        _configHashes[instanceType] = hash;
        _lastLoadTimes[identifier] = DateTime.UtcNow;

        return true;
    }

    /// <summary>
    /// Chooses how often a blocklist at <paramref name="path"/> is re-read: the fast interval for
    /// local files and for the official host, the default interval for any other absolute URI.
    /// </summary>
    /// <param name="path">Filesystem path or URL of the blocklist; may be <c>null</c> or empty.</param>
    private static TimeSpan GetLoadInterval(string? path)
    {
        TimeSpan fastInterval = TimeSpan.FromMinutes(FastLoadIntervalMinutes);

        // Anything that does not parse as an absolute URI is treated as a (relative) local path.
        if (string.IsNullOrEmpty(path) || !Uri.TryCreate(path, UriKind.Absolute, out Uri? uri))
        {
            return fastInterval;
        }

        // Absolute filesystem paths (e.g. "C:\lists\blocklist.txt", or rooted paths on Unix) also
        // parse as absolute file:// URIs with an empty host. Without this check they would be
        // mistaken for remote lists and only re-read on the slow interval.
        if (uri.IsFile)
        {
            return fastInterval;
        }

        return uri.Host.Equals(OfficialBlocklistHost, StringComparison.OrdinalIgnoreCase)
            ? fastInterval
            : TimeSpan.FromHours(DefaultLoadIntervalHours);
    }

    /// <summary>
    /// Tells whether the blocklist tracked under <paramref name="identifier"/> was never loaded or
    /// was last loaded at least <paramref name="interval"/> ago.
    /// </summary>
    private bool ShouldReloadBlocklist(string identifier, TimeSpan interval)
    {
        if (!_lastLoadTimes.TryGetValue(identifier, out DateTime lastLoad))
        {
            return true;
        }

        return DateTime.UtcNow - lastLoad >= interval;
    }

    /// <summary>
    /// Refreshes the cached known-malware patterns from <c>MalwareListUrl</c> when the fast interval
    /// has elapsed. Failures are logged as warnings and swallowed; because the load time is only
    /// recorded on success, a failed load is attempted again on the next call.
    /// </summary>
    /// <param name="fileReader">Reader used to fetch the pattern list.</param>
    private async Task LoadMalwarePatternsAsync(FileReader fileReader)
    {
        TimeSpan malwareInterval = TimeSpan.FromMinutes(FastLoadIntervalMinutes);

        if (!ShouldReloadBlocklist(MalwareListKey, malwareInterval))
        {
            return;
        }

        try
        {
            _logger.LogDebug("Loading malware patterns");

            string[] filePatterns = await fileReader.ReadContentAsync(MalwareListUrl);

            long startTime = Stopwatch.GetTimestamp();

            // Entries are stored as-is, so the bag can be filled directly from the array.
            ConcurrentBag<string> patterns = new(filePatterns);

            TimeSpan elapsed = Stopwatch.GetElapsedTime(startTime);

            _cache.Set(CacheKeys.KnownMalwarePatterns(), patterns);
            _lastLoadTimes[MalwareListKey] = DateTime.UtcNow;

            _logger.LogDebug("loaded {count} known malware patterns", patterns.Count);
            _logger.LogDebug("malware patterns loaded in {elapsed} ms", elapsed.TotalMilliseconds);
        }
        catch (Exception ex)
        {
            _logger.LogWarning(ex, "Failed to load malware patterns from {url}", MalwareListUrl);
        }
    }

    /// <summary>
    /// Reads the blocklist at <c>blocklistSettings.BlocklistPath</c>, splits its entries into plain
    /// patterns and compiled regexes (entries prefixed with <c>regex:</c>), and caches both together
    /// with the blocklist type for <paramref name="instanceType"/>. Invalid regexes are logged and skipped.
    /// </summary>
    /// <param name="blocklistSettings">Settings providing the blocklist path and type.</param>
    /// <param name="instanceType">Instance the cache entries are stored for.</param>
    /// <param name="fileReader">Reader used to fetch the blocklist content.</param>
    private async Task LoadPatternsAndRegexesAsync(BlocklistSettings blocklistSettings, InstanceType instanceType, FileReader fileReader)
    {
        if (string.IsNullOrEmpty(blocklistSettings.BlocklistPath))
        {
            return;
        }

        string[] filePatterns = await fileReader.ReadContentAsync(blocklistSettings.BlocklistPath);

        long startTime = Stopwatch.GetTimestamp();
        ParallelOptions options = new() { MaxDegreeOfParallelism = 5 };
        ConcurrentBag<string> patterns = [];
        ConcurrentBag<Regex> regexes = [];

        Parallel.ForEach(filePatterns, options, entry =>
        {
            // The prefix is a fixed machine token, so compare ordinally rather than by culture.
            if (!entry.StartsWith(RegexPrefix, StringComparison.Ordinal))
            {
                patterns.Add(entry);
                return;
            }

            string expression = entry[RegexPrefix.Length..];

            try
            {
                // NOTE(review): regexes are built without a match timeout and the list may come
                // from a URL — confirm the match sites tolerate pathological patterns.
                regexes.Add(new Regex(expression, RegexOptions.Compiled));
            }
            catch (ArgumentException)
            {
                _logger.LogWarning("invalid regex | {pattern}", expression);
            }
        });

        TimeSpan elapsed = Stopwatch.GetElapsedTime(startTime);

        _cache.Set(CacheKeys.BlocklistType(instanceType), blocklistSettings.BlocklistType);
        _cache.Set(CacheKeys.BlocklistPatterns(instanceType), patterns);
        _cache.Set(CacheKeys.BlocklistRegexes(instanceType), regexes);

        _logger.LogDebug("loaded {count} patterns", patterns.Count);
        _logger.LogDebug("loaded {count} regexes", regexes.Count);
        _logger.LogDebug("blocklist loaded in {elapsed} ms | {path}", elapsed.TotalMilliseconds, blocklistSettings.BlocklistPath);
    }

    /// <summary>
    /// Builds a lowercase hexadecimal SHA-256 fingerprint of the settings that affect blocklist
    /// loading (enabled flag, path and type), used to detect configuration changes.
    /// </summary>
    private static string GenerateSettingsHash(BlocklistSettings blocklistSettings)
    {
        // Create a string that represents the relevant blocklist configuration
        string configStr = $"{blocklistSettings.Enabled}|{blocklistSettings.BlocklistPath ?? string.Empty}|{blocklistSettings.BlocklistType}";

        // One-shot static hashing avoids allocating and disposing a SHA256 instance.
        byte[] hashBytes = SHA256.HashData(Encoding.UTF8.GetBytes(configStr));

        return Convert.ToHexString(hashBytes).ToLowerInvariant();
    }
}