Refactor admin so all tests pass (#190)

This commit is contained in:
Leendert de Borst
2024-12-23 12:16:05 +01:00
parent 22538ae000
commit 77a48ea4e9
16 changed files with 247 additions and 119 deletions

View File

@@ -8,11 +8,13 @@
protected override async Task OnInitializedAsync()
{
// Sign out the user.
// NOTE: the try/catch below is a workaround for the issue that the sign out does not work when
// NOTE: the try/catch below is a workaround for the issue that the sign-out does not work when
// the server session is already started.
try
{
await UserService.LoadCurrentUserAsync();
var username = UserService.User().UserName;
try
{
await SignInManager.SignOutAsync();
@@ -22,11 +24,12 @@
// Redirect to the home page with hard refresh.
NavigationService.RedirectTo("/", true);
}
catch
catch (Exception ex)
{
// Hard refresh current page if sign out fails. When an interactive server session is already started
// the sign out will fail because it tries to mutate cookies which is only possible when the server
// the sign-out will fail because it tries to mutate cookies which is only possible when the server
// session is not started yet.
Console.WriteLine(ex);
await AuthLoggingService.LogAuthEventSuccessAsync(username!, AuthEventType.Logout);
NavigationService.RedirectTo(NavigationService.Uri, true);
}

View File

@@ -171,7 +171,7 @@
try
{
InitInProgress = true;
var dbContext = await DbContextFactory.CreateDbContextAsync();
await using var dbContext = await DbContextFactory.CreateDbContextAsync();
ServiceStatus = await dbContext.WorkerServiceStatuses.ToListAsync();
foreach (var service in Services)
@@ -197,7 +197,7 @@
/// </summary>
private async Task<bool> UpdateServiceStatus(string serviceName, bool newStatus)
{
var dbContext = await DbContextFactory.CreateDbContextAsync();
await using var dbContext = await DbContextFactory.CreateDbContextAsync();
var entry = await dbContext.WorkerServiceStatuses.Where(x => x.ServiceName == serviceName).FirstOrDefaultAsync();
if (entry != null)
{
@@ -213,8 +213,8 @@
return false;
}
dbContext = await DbContextFactory.CreateDbContextAsync();
var check = await dbContext.WorkerServiceStatuses.Where(x => x.ServiceName == serviceName).FirstAsync();
await using var dbContextInner = await DbContextFactory.CreateDbContextAsync();
var check = await dbContextInner.WorkerServiceStatuses.Where(x => x.ServiceName == serviceName).FirstAsync();
if (check.CurrentStatus == newDesiredStatus)
{
return true;

View File

@@ -2,7 +2,6 @@
@using System.ComponentModel.DataAnnotations
@using Microsoft.AspNetCore.Identity
@inject UserManager<AdminUser> UserManager
@inject ILogger<ChangePassword> Logger
@@ -41,15 +40,13 @@
private async Task OnValidSubmitAsync()
{
var changePasswordResult = await UserManager.ChangePasswordAsync(UserService.User(), Input.OldPassword, Input.NewPassword);
var user = UserService.User();
user.LastPasswordChanged = DateTime.UtcNow;
await UserService.UpdateUserAsync(user);
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
// Clear the password fields
Input.OldPassword = "";
Input.NewPassword = "";
Input.ConfirmPassword = "";
var changePasswordResult = await UserManager.ChangePasswordAsync(user, Input.OldPassword, Input.NewPassword);
if (!changePasswordResult.Succeeded)
{
@@ -57,10 +54,15 @@
return;
}
user.LastPasswordChanged = DateTime.UtcNow;
await UserManager.UpdateAsync(user);
Input.OldPassword = "";
Input.NewPassword = "";
Input.ConfirmPassword = "";
Logger.LogInformation("User changed their password successfully.");
GlobalNotificationService.AddSuccessMessage("Your password has been changed.");
NavigationService.RedirectToCurrentPage();
}
@@ -82,5 +84,4 @@
[Compare("NewPassword", ErrorMessage = "The new password and confirmation password do not match.")]
public string ConfirmPassword { get; set; } = "";
}
}

View File

@@ -31,7 +31,13 @@
/// <inheritdoc />
protected override async Task OnInitializedAsync()
{
if (!await UserManager.GetTwoFactorEnabledAsync(UserService.User()))
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
if (!await UserManager.GetTwoFactorEnabledAsync(user))
{
throw new InvalidOperationException("Cannot disable 2FA for user as it's not currently enabled.");
}
@@ -39,7 +45,13 @@
private async Task OnSubmitAsync()
{
var disable2FaResult = await UserManager.SetTwoFactorEnabledAsync(UserService.User(), false);
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
var disable2FaResult = await UserManager.SetTwoFactorEnabledAsync(user, false);
if (!disable2FaResult.Succeeded)
{
await AuthLoggingService.LogAuthEventFailAsync(UserService.User().UserName!, AuthEventType.TwoFactorAuthDisable, AuthFailureReason.Unknown);

View File

@@ -13,6 +13,12 @@
<LayoutPageTitle>Configure authenticator app</LayoutPageTitle>
@if (_isLoading)
{
<LoadingIndicator />
return;
}
@if (RecoveryCodes is not null)
{
<ShowRecoveryCodes RecoveryCodes="RecoveryCodes.ToArray()"/>
@@ -69,15 +75,20 @@ else
private string? SharedKey { get; set; }
private string? AuthenticatorUri { get; set; }
private IEnumerable<string>? RecoveryCodes { get; set; }
private bool _isLoading = true;
[SupplyParameterFromForm] private InputModel Input { get; set; } = new();
/// <inheritdoc />
protected override async Task OnInitializedAsync()
/// <inheritdoc/>
protected override async Task OnAfterRenderAsync(bool firstRender)
{
await base.OnInitializedAsync();
await LoadSharedKeyAndQrCodeUriAsync(UserService.User());
await JsInvokeService.RetryInvokeAsync("generateQrCode", TimeSpan.Zero, 5, "authenticator-uri");
if (firstRender)
{
await LoadSharedKeyAndQrCodeUriAsync();
_isLoading = false;
StateHasChanged();
await JsInvokeService.RetryInvokeAsync("generateQrCode", TimeSpan.Zero, 5, "authenticator-uri");
}
}
private async Task OnValidSubmitAsync()
@@ -85,8 +96,13 @@ else
// Strip spaces and hyphens
var verificationCode = Input.Code.Replace(" ", string.Empty).Replace("-", string.Empty);
var is2FaTokenValid = await UserManager.VerifyTwoFactorTokenAsync(
UserService.User(), UserManager.Options.Tokens.AuthenticatorTokenProvider, verificationCode);
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
var is2FaTokenValid = await UserManager.VerifyTwoFactorTokenAsync(user, UserManager.Options.Tokens.AuthenticatorTokenProvider, verificationCode);
if (!is2FaTokenValid)
{
@@ -94,25 +110,31 @@ else
return;
}
await UserManager.SetTwoFactorEnabledAsync(UserService.User(), true);
await UserManager.SetTwoFactorEnabledAsync(user, true);
await AuthLoggingService.LogAuthEventSuccessAsync(UserService.User().UserName!, AuthEventType.TwoFactorAuthEnable);
Logger.LogInformation("User with ID '{UserId}' has enabled 2FA with an authenticator app.", UserService.User().Id);
GlobalNotificationService.AddSuccessMessage("Your authenticator app has been verified.");
if (await UserManager.CountRecoveryCodesAsync(UserService.User()) == 0)
if (await UserManager.CountRecoveryCodesAsync(user) == 0)
{
RecoveryCodes = await UserManager.GenerateNewTwoFactorRecoveryCodesAsync(UserService.User(), 10);
RecoveryCodes = await UserManager.GenerateNewTwoFactorRecoveryCodesAsync(user, 10);
}
else
{
// Navigate back to the two factor authentication page.
// Navigate back to the two-factor authentication page.
NavigationService.RedirectTo("account/manage/2fa", forceLoad: true);
}
}
private async ValueTask LoadSharedKeyAndQrCodeUriAsync(AdminUser user)
private async ValueTask LoadSharedKeyAndQrCodeUriAsync()
{
// Load the authenticator key & QR code URI to display on the form
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
// Load the authenticator key & QR code URI to display on the form.
var unformattedKey = await UserManager.GetAuthenticatorKeyAsync(user);
if (string.IsNullOrEmpty(unformattedKey))
{
@@ -126,7 +148,7 @@ else
AuthenticatorUri = GenerateQrCodeUri(username!, unformattedKey!);
}
private string FormatKey(string unformattedKey)
private static string FormatKey(string unformattedKey)
{
var result = new StringBuilder();
int currentPosition = 0;

View File

@@ -7,9 +7,9 @@
<LayoutPageTitle>Generate two-factor authentication (2FA) recovery codes</LayoutPageTitle>
@if (recoveryCodes is not null)
@if (_recoveryCodes is not null)
{
<ShowRecoveryCodes RecoveryCodes="recoveryCodes.ToArray()"/>
<ShowRecoveryCodes RecoveryCodes="_recoveryCodes.ToArray()"/>
}
else
{
@@ -35,14 +35,20 @@ else
}
@code {
private IEnumerable<string>? recoveryCodes;
private IEnumerable<string>? _recoveryCodes;
/// <inheritdoc />
protected override async Task OnInitializedAsync()
{
await base.OnInitializedAsync();
var isTwoFactorEnabled = await UserManager.GetTwoFactorEnabledAsync(UserService.User());
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
var isTwoFactorEnabled = await UserManager.GetTwoFactorEnabledAsync(user);
if (!isTwoFactorEnabled)
{
throw new InvalidOperationException("Cannot generate recovery codes for user because they do not have 2FA enabled.");
@@ -51,11 +57,16 @@ else
private async Task GenerateCodes()
{
var userId = await UserManager.GetUserIdAsync(UserService.User());
recoveryCodes = await UserManager.GenerateNewTwoFactorRecoveryCodesAsync(UserService.User(), 10);
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
_recoveryCodes = await UserManager.GenerateNewTwoFactorRecoveryCodesAsync(user, 10);
GlobalNotificationService.AddSuccessMessage("You have generated new recovery codes.");
Logger.LogInformation("User with ID '{UserId}' has generated new 2FA recovery codes.", userId);
Logger.LogInformation("User with ID '{UserId}' has generated new 2FA recovery codes.", UserService.User().Id);
}
}

View File

@@ -30,10 +30,15 @@
@code {
private async Task OnSubmitAsync()
{
await UserManager.SetTwoFactorEnabledAsync(UserService.User(), false);
await UserManager.ResetAuthenticatorKeyAsync(UserService.User());
var userId = await UserManager.GetUserIdAsync(UserService.User());
Logger.LogInformation("User with ID '{UserId}' has reset their authentication app key.", userId);
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
await UserManager.SetTwoFactorEnabledAsync(user, false);
await UserManager.ResetAuthenticatorKeyAsync(user);
Logger.LogInformation("User with ID '{UserId}' has reset their authentication app key.", UserService.User().Id);
GlobalNotificationService.AddSuccessMessage("Your authenticator app key has been reset, you will need to re-configure your authenticator app using the new key.");

View File

@@ -5,29 +5,29 @@
<LayoutPageTitle>Two-factor authentication (2FA)</LayoutPageTitle>
@if (is2FaEnabled)
@if (_is2FaEnabled)
{
<div class="p-4 bg-white border border-gray-200 rounded-lg shadow-sm dark:border-gray-700 sm:p-6 dark:bg-gray-800">
<h3 class="text-xl font-bold text-gray-900 dark:text-white mb-4">Two-factor authentication (2FA)</h3>
@if (recoveryCodesLeft == 0)
@if (_recoveryCodesLeft == 0)
{
<div class="mb-4 p-4 bg-red-100 border-l-4 border-red-500 text-red-700 dark:bg-red-900 dark:text-red-100">
<p class="font-bold">You have no recovery codes left.</p>
<p>You must <a href="account/manage/generate-recovery-codes" class="text-red-800 dark:text-red-200 underline">generate a new set of recovery codes</a> before you can log in with a recovery code.</p>
</div>
}
else if (recoveryCodesLeft == 1)
else if (_recoveryCodesLeft == 1)
{
<div class="mb-4 p-4 bg-red-100 border-l-4 border-red-500 text-red-700 dark:bg-red-900 dark:text-red-100">
<p class="font-bold">You have 1 recovery code left.</p>
<p>You can <a href="account/manage/generate-recovery-codes" class="text-red-800 dark:text-red-200 underline">generate a new set of recovery codes</a>.</p>
</div>
}
else if (recoveryCodesLeft <= 3)
else if (_recoveryCodesLeft <= 3)
{
<div class="mb-4 p-4 bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 dark:bg-yellow-900 dark:text-yellow-100">
<p class="font-bold">You have @recoveryCodesLeft recovery codes left.</p>
<p class="font-bold">You have @_recoveryCodesLeft recovery codes left.</p>
<p>You should <a href="account/manage/generate-recovery-codes" class="text-yellow-800 dark:text-yellow-200 underline">generate a new set of recovery codes</a>.</p>
</div>
}
@@ -42,7 +42,7 @@
<div class="p-4 bg-white border border-gray-200 rounded-lg shadow-sm dark:border-gray-700 sm:p-6 dark:bg-gray-800">
<h4 class="text-lg font-semibold text-gray-900 dark:text-white mb-4">Authenticator app</h4>
<div class="flex flex-col sm:flex-row space-y-2 sm:space-y-0 sm:space-x-2">
@if (!hasAuthenticator)
@if (!_hasAuthenticator)
{
<LinkButton Href="account/manage/enable-authenticator" Color="primary" Text="Add authenticator app" />
}
@@ -55,17 +55,23 @@
</div>
@code {
private bool hasAuthenticator;
private int recoveryCodesLeft;
private bool is2FaEnabled;
private bool _hasAuthenticator;
private int _recoveryCodesLeft;
private bool _is2FaEnabled;
/// <inheritdoc />
protected override async Task OnInitializedAsync()
{
await base.OnInitializedAsync();
hasAuthenticator = await UserManager.GetAuthenticatorKeyAsync(UserService.User()) is not null;
is2FaEnabled = await UserManager.GetTwoFactorEnabledAsync(UserService.User());
recoveryCodesLeft = await UserManager.CountRecoveryCodesAsync(UserService.User());
var user = await UserManager.FindByIdAsync(UserService.User().Id);
if (user == null)
{
throw new InvalidOperationException("User not found.");
}
_hasAuthenticator = await UserManager.GetAuthenticatorKeyAsync(user) is not null;
_is2FaEnabled = await UserManager.GetTwoFactorEnabledAsync(user);
_recoveryCodesLeft = await UserManager.CountRecoveryCodesAsync(user);
}
}

View File

@@ -84,15 +84,6 @@ public abstract class MainBase : OwningComponentBase
/// </summary>
protected List<BreadcrumbItem> BreadcrumbItems { get; } = new();
/// <summary>
/// Gets the AliasServerDbContext instance asynchronously.
/// </summary>
/// <returns>The AliasServerDbContext instance.</returns>
protected async Task<AliasServerDbContext> GetDbContextAsync()
{
return await DbContextFactory.CreateDbContextAsync();
}
/// <inheritdoc />
protected override async Task OnInitializedAsync()
{

View File

@@ -64,7 +64,7 @@
/// </summary>
public async Task RefreshData()
{
var dbContext = await DbContextFactory.CreateDbContextAsync();
await using var dbContext = await DbContextFactory.CreateDbContextAsync();
var query = dbContext.TaskRunnerJobs.AsQueryable();
// Apply sorting

View File

@@ -122,7 +122,7 @@
{
try
{
var dbContext = await DbContextFactory.CreateDbContextAsync();
await using var dbContext = await DbContextFactory.CreateDbContextAsync();
var job = new TaskRunnerJob
{
Name = nameof(TaskRunnerJobType.Maintenance),

View File

@@ -20,7 +20,6 @@ using Microsoft.EntityFrameworkCore;
/// <param name="httpContextAccessor">HttpContextManager instance.</param>
public class UserService(IAliasServerDbContextFactory dbContextFactory, UserManager<AdminUser> userManager, IHttpContextAccessor httpContextAccessor)
{
private const string AdminRole = "Admin";
private AdminUser? _user;
/// <summary>
@@ -28,11 +27,6 @@ public class UserService(IAliasServerDbContextFactory dbContextFactory, UserMana
/// </summary>
public event Action OnChange = () => { };
/// <summary>
/// Gets a value indicating whether the User is loaded and available, false if not. Use this before accessing User() method.
/// </summary>
public bool UserLoaded => _user != null;
/// <summary>
/// Returns all users.
/// </summary>
@@ -85,7 +79,7 @@ public class UserService(IAliasServerDbContextFactory dbContextFactory, UserMana
// Load user from database. Use a new context everytime to ensure we get the latest data.
var userName = httpContextAccessor.HttpContext?.User.Identity?.Name ?? string.Empty;
var dbContext = await dbContextFactory.CreateDbContextAsync();
await using var dbContext = await dbContextFactory.CreateDbContextAsync();
var user = await dbContext.AdminUsers.FirstOrDefaultAsync(u => u.UserName == userName);
if (user != null)
{

View File

@@ -15,8 +15,10 @@ using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Mvc.Testing;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Npgsql;
/// <summary>
/// Admin web application factory fixture for integration tests.
@@ -25,15 +27,10 @@ using Microsoft.Extensions.Hosting;
public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFactory<TEntryPoint>
where TEntryPoint : class
{
/// <summary>
/// The DbConnection instance that is created for the test.
/// </summary>
private DbConnection _dbConnection;
/// <summary>
/// The DbContextFactory instance that is created for the test.
/// </summary>
private IDbContextFactory<AliasServerDbContext> _dbContextFactory = null!;
private IAliasServerDbContextFactory _dbContextFactory = null!;
/// <summary>
/// The cached DbContext instance that can be used during the test.
@@ -41,13 +38,9 @@ public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFact
private AliasServerDbContext? _dbContext;
/// <summary>
/// Initializes a new instance of the <see cref="WebApplicationAdminFactoryFixture{TEntryPoint}"/> class.
/// The name of the temporary test database.
/// </summary>
public WebApplicationAdminFactoryFixture()
{
_dbConnection = new SqliteConnection("DataSource=:memory:");
_dbConnection.Open();
}
private string? _tempDbName;
/// <summary>
/// Gets or sets the port the web application kestrel host will listen on.
@@ -70,14 +63,46 @@ public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFact
}
/// <summary>
/// Disposes the DbConnection instance.
/// Disposes the DbConnection instance and drops the temporary database.
/// </summary>
/// <returns>ValueTask.</returns>
public override ValueTask DisposeAsync()
/// <returns>Task.</returns>
public override async ValueTask DisposeAsync()
{
_dbConnection.Dispose();
if (_dbContext != null)
{
await _dbContext.DisposeAsync();
_dbContext = null;
}
if (!string.IsNullOrEmpty(_tempDbName))
{
// Create a connection to 'postgres' database to drop the test database
using var conn = new NpgsqlConnection("Host=localhost;Port=5432;Database=postgres;Username=aliasvault;Password=password");
await conn.OpenAsync();
// First terminate existing connections
using (var cmd = conn.CreateCommand())
{
cmd.CommandText = $"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = '{_tempDbName}';
""";
await cmd.ExecuteNonQueryAsync();
}
// Then drop the database in a separate command
using (var cmd = conn.CreateCommand())
{
cmd.CommandText = $"""
DROP DATABASE IF EXISTS "{_tempDbName}";
""";
await cmd.ExecuteNonQueryAsync();
}
}
GC.SuppressFinalize(this);
return base.DisposeAsync();
await base.DisposeAsync();
}
/// <inheritdoc />
@@ -92,7 +117,7 @@ public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFact
var host = base.CreateHost(builder);
// Get the DbContextFactory instance and store it for later use during tests.
_dbContextFactory = host.Services.GetRequiredService<IDbContextFactory<AliasServerDbContext>>();
_dbContextFactory = host.Services.GetRequiredService<IAliasServerDbContextFactory>();
return host;
}
@@ -102,6 +127,20 @@ public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFact
{
SetEnvironmentVariables();
builder.ConfigureAppConfiguration((context, configBuilder) =>
{
configBuilder.Sources.Clear();
_tempDbName = $"aliasdb_test_{Guid.NewGuid()}";
configBuilder.AddJsonFile("appsettings.json", optional: true);
configBuilder.AddInMemoryCollection(new Dictionary<string, string?>
{
["DatabaseProvider"] = "postgresql",
["ConnectionStrings:AliasServerDbContext"] = $"Host=localhost;Port=5432;Database={_tempDbName};Username=aliasvault;Password=password",
});
});
builder.ConfigureServices(services =>
{
RemoveExistingRegistrations(services);
@@ -126,7 +165,6 @@ public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFact
private static void RemoveExistingRegistrations(IServiceCollection services)
{
var descriptorsToRemove = services.Where(d =>
d.ServiceType.ToString().Contains("AliasServerDbContext") ||
d.ServiceType == typeof(VersionedContentService)).ToList();
foreach (var descriptor in descriptorsToRemove)
@@ -142,10 +180,10 @@ public class WebApplicationAdminFactoryFixture<TEntryPoint> : WebApplicationFact
private void AddNewRegistrations(IServiceCollection services)
{
// Add the DbContextFactory
services.AddDbContextFactory<AliasServerDbContext>(options =>
/*services.AddDbContextFactory<AliasServerDbContext>(options =>
{
options.UseSqlite(_dbConnection).UseLazyLoadingProxies();
});
});*/
// Add the VersionedContentService
services.AddSingleton(new VersionedContentService("../../../../../AliasVault.Admin/wwwroot"));

View File

@@ -15,8 +15,10 @@ using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Mvc.Testing;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Npgsql;
/// <summary>
/// API web application factory fixture for integration tests.
@@ -25,15 +27,10 @@ using Microsoft.Extensions.Hosting;
public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactory<TEntryPoint>
where TEntryPoint : class
{
/// <summary>
/// The DbConnection instance that is created for the test.
/// </summary>
private DbConnection _dbConnection;
/// <summary>
/// The DbContextFactory instance that is created for the test.
/// </summary>
private IDbContextFactory<AliasServerDbContext> _dbContextFactory = null!;
private IAliasServerDbContextFactory _dbContextFactory = null!;
/// <summary>
/// The cached DbContext instance that can be used during the test.
@@ -41,13 +38,9 @@ public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactor
private AliasServerDbContext? _dbContext;
/// <summary>
/// Initializes a new instance of the <see cref="WebApplicationApiFactoryFixture{TEntryPoint}"/> class.
/// The name of the temporary test database.
/// </summary>
public WebApplicationApiFactoryFixture()
{
_dbConnection = new SqliteConnection("DataSource=:memory:");
_dbConnection.Open();
}
private string? _tempDbName;
/// <summary>
/// Gets or sets the port the web application kestrel host will listen on.
@@ -75,14 +68,46 @@ public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactor
}
/// <summary>
/// Disposes the DbConnection instance.
/// Disposes the DbConnection instance and drops the temporary database.
/// </summary>
/// <returns>ValueTask.</returns>
public override ValueTask DisposeAsync()
/// <returns>Task.</returns>
public override async ValueTask DisposeAsync()
{
_dbConnection.Dispose();
if (_dbContext != null)
{
await _dbContext.DisposeAsync();
_dbContext = null;
}
if (!string.IsNullOrEmpty(_tempDbName))
{
// Create a connection to 'postgres' database to drop the test database
using var conn = new NpgsqlConnection("Host=localhost;Port=5432;Database=postgres;Username=aliasvault;Password=password");
await conn.OpenAsync();
// First terminate existing connections
using (var cmd = conn.CreateCommand())
{
cmd.CommandText = $"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = '{_tempDbName}';
""";
await cmd.ExecuteNonQueryAsync();
}
// Then drop the database in a separate command
using (var cmd = conn.CreateCommand())
{
cmd.CommandText = $"""
DROP DATABASE IF EXISTS "{_tempDbName}";
""";
await cmd.ExecuteNonQueryAsync();
}
}
GC.SuppressFinalize(this);
return base.DisposeAsync();
await base.DisposeAsync();
}
/// <inheritdoc />
@@ -97,7 +122,7 @@ public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactor
var host = base.CreateHost(builder);
// Get the DbContextFactory instance and store it for later use during tests.
_dbContextFactory = host.Services.GetRequiredService<IDbContextFactory<AliasServerDbContext>>();
_dbContextFactory = host.Services.GetRequiredService<IAliasServerDbContextFactory>();
return host;
}
@@ -107,6 +132,20 @@ public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactor
{
SetEnvironmentVariables();
builder.ConfigureAppConfiguration((context, configBuilder) =>
{
configBuilder.Sources.Clear();
_tempDbName = $"aliasdb_test_{Guid.NewGuid()}";
configBuilder.AddJsonFile("appsettings.json", optional: true);
configBuilder.AddInMemoryCollection(new Dictionary<string, string?>
{
["DatabaseProvider"] = "postgresql",
["ConnectionStrings:AliasServerDbContext"] = $"Host=localhost;Port=5432;Database={_tempDbName};Username=aliasvault;Password=password",
});
});
builder.ConfigureServices(services =>
{
RemoveExistingRegistrations(services);
@@ -131,7 +170,6 @@ public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactor
private static void RemoveExistingRegistrations(IServiceCollection services)
{
var descriptorsToRemove = services.Where(d =>
d.ServiceType.ToString().Contains("AliasServerDbContext") ||
d.ServiceType == typeof(ITimeProvider)).ToList();
foreach (var descriptor in descriptorsToRemove)
@@ -147,10 +185,11 @@ public class WebApplicationApiFactoryFixture<TEntryPoint> : WebApplicationFactor
private void AddNewRegistrations(IServiceCollection services)
{
// Add the DbContextFactory
services.AddDbContextFactory<AliasServerDbContext>(options =>
/*services.AddDbContextFactory<AliasServerDbContext>(options =>
{
options.UseSqlite(_dbConnection).UseLazyLoadingProxies();
});
});*/
// services.AddSingleton<IAliasServerDbContextFactory, SqliteDbContextFactory>();
// Add TestTimeProvider
services.AddSingleton<ITimeProvider>(TimeProvider);

View File

@@ -38,8 +38,8 @@ public class ServerSettingsTests : AdminPlaywrightTest
await Page.Locator("input[id='schedule']").FillAsync("03:30");
// Uncheck Sunday and Saturday from maintenance days
await Page.Locator("input[id='day_7']").UncheckAsync(); // Sunday
await Page.Locator("input[id='day_6']").UncheckAsync(); // Saturday
await Page.Locator("input[id='day_7']").UncheckAsync(); // Sunday
// Save changes
var saveButton = Page.Locator("text=Save changes");
@@ -75,6 +75,9 @@ public class ServerSettingsTests : AdminPlaywrightTest
await Page.ReloadAsync();
await WaitForUrlAsync("settings/server", "Server settings");
// Wait for 0.5sec to ensure the page is fully loaded.
await Task.Delay(500);
var generalLogRetentionValue = await Page.Locator("input[id='generalLogRetention']").InputValueAsync();
Assert.That(generalLogRetentionValue, Is.EqualTo("45"), "General log retention value not persisted after refresh");

View File

@@ -29,6 +29,9 @@ public class TwoFactorAuthLockoutTests : AdminPlaywrightTest
var enable2FaButton = Page.GetByRole(AriaRole.Link, new() { Name = "Add authenticator app" });
await enable2FaButton.ClickAsync();
// Wait for QR code to appear.
await WaitForUrlAsync("account/manage/enable-authenticator", "Scan the QR Code or enter this key");
// Extract secret key from page.
var secretKey = await Page.TextContentAsync("kbd");