Refactor and cleanup (#80)

This commit is contained in:
Leendert de Borst
2024-08-30 21:36:16 +02:00
parent 95949508ba
commit 072e63e98f
10 changed files with 39 additions and 67 deletions

View File

@@ -303,7 +303,7 @@ public class AuthController(IDbContextFactory<AliasServerDbContext> dbContextFac
public async Task<IActionResult> Register([FromBody] SrpSignup model)
{
// Validate username, disallow "admin" as a username.
if (model.Username.ToLower() == "admin")
if (string.Equals(model.Username, "admin", StringComparison.OrdinalIgnoreCase))
{
return BadRequest(ServerValidationErrorResponse.Create(["Username 'admin' is not allowed."], 400));
}

View File

@@ -39,7 +39,7 @@ public class TwoFactorAuthController(IDbContextFactory<AliasServerDbContext> dbC
var user = await GetCurrentUserAsync();
if (user is null)
{
return Unauthorized("Not authenticated.");
return Unauthorized();
}
var twoFactorEnabled = await GetUserManager().GetTwoFactorEnabledAsync(user);
@@ -56,7 +56,7 @@ public class TwoFactorAuthController(IDbContextFactory<AliasServerDbContext> dbC
var user = await GetCurrentUserAsync();
if (user is null)
{
return Unauthorized("Not authenticated.");
return Unauthorized();
}
var authenticatorKey = await GetUserManager().GetAuthenticatorKeyAsync(user);
@@ -83,7 +83,7 @@ public class TwoFactorAuthController(IDbContextFactory<AliasServerDbContext> dbC
var user = await GetCurrentUserAsync();
if (user is null)
{
return Unauthorized("Not authenticated.");
return Unauthorized();
}
var isValid = await GetUserManager().VerifyTwoFactorTokenAsync(user, GetUserManager().Options.Tokens.AuthenticatorTokenProvider, code);
@@ -113,7 +113,7 @@ public class TwoFactorAuthController(IDbContextFactory<AliasServerDbContext> dbC
var user = await GetCurrentUserAsync();
if (user is null)
{
return Unauthorized("Not authenticated.");
return Unauthorized();
}
await using var context = await dbContextFactory.CreateDbContextAsync();

View File

@@ -46,6 +46,6 @@ public class TestController(UserManager<AliasVaultUser> userManager) : Authentic
public IActionResult TestCallError()
{
// Throw an exception here to test error handling.
throw new ApplicationException("Test error");
throw new ArgumentException("Test error");
}
}

View File

@@ -83,30 +83,4 @@
IsLoading = false;
StateHasChanged();
}
/// <summary>
/// Revokes a specific session (refresh token) for the current user.
/// </summary>
/// <param name="id">The unique identifier of the session to revoke.</param>
/// <returns>A task representing the asynchronous operation.</returns>
private async Task RevokeSession(Guid id)
{
try
{
var response = await Http.DeleteAsync($"api/v1/Security/sessions/{id}");
if (response.IsSuccessStatusCode)
{
GlobalNotificationService.AddSuccessMessage("Session revoked successfully.", true);
await OnSessionsChanged.InvokeAsync();
}
else
{
GlobalNotificationService.AddErrorMessage("Failed to revoke session.", true);
}
}
catch (Exception ex)
{
GlobalNotificationService.AddErrorMessage($"Failed to revoke session: {ex.Message}.", true);
}
}
}

View File

@@ -51,7 +51,7 @@ public static class DatabaseConfiguration
return services;
}
private static DbConnection CreateAndConfigureSqliteConnection(string connectionString)
private static SqliteConnection CreateAndConfigureSqliteConnection(string connectionString)
{
var connection = new SqliteConnection(connectionString);
connection.Open();

View File

@@ -1,4 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
// <auto-generated />
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable

View File

@@ -1,4 +1,5 @@
using Microsoft.EntityFrameworkCore.Migrations;
// <auto-generated />
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable

View File

@@ -1,4 +1,5 @@
using System;
// <auto-generated />
using System;
using Microsoft.EntityFrameworkCore.Migrations;
#nullable disable

View File

@@ -108,6 +108,8 @@ public class AuthTests : ClientPlaywrightTest
// Wait for account lockout message.
await WaitForUrlAsync("user/login**", "locked out");
var pageContent = await Page.TextContentAsync("body");
Assert.That(pageContent, Does.Contain("locked out"), "No account lockout message.");
}
/// <summary>

View File

@@ -44,11 +44,11 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
DeviceType = DetermineDeviceType(httpContext),
OperatingSystem = DetermineOperatingSystem(httpContext),
Browser = DetermineBrowser(httpContext),
Country = DetermineCountry(httpContext),
IsSuspiciousActivity = DetermineSuspiciousActivity(),
Country = DetermineCountry(),
IsSuspiciousActivity = false,
};
dbContext.AuthLogs.Add(authAttempt);
await dbContext.AuthLogs.AddAsync(authAttempt);
await dbContext.SaveChangesAsync();
}
@@ -78,11 +78,11 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
DeviceType = DetermineDeviceType(httpContext),
OperatingSystem = DetermineOperatingSystem(httpContext),
Browser = DetermineBrowser(httpContext),
Country = DetermineCountry(httpContext),
IsSuspiciousActivity = DetermineSuspiciousActivity(),
Country = DetermineCountry(),
IsSuspiciousActivity = false,
};
dbContext.AuthLogs.Add(authAttempt);
await dbContext.AuthLogs.AddAsync(authAttempt);
await dbContext.SaveChangesAsync();
}
@@ -91,11 +91,14 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
/// </summary>
/// <param name="context">The HttpContext containing the request information.</param>
/// <returns>A string representing the device type: "Mobile", "Tablet", "Smart TV", "Desktop", or "Unknown".</returns>
private string DetermineDeviceType(HttpContext? context)
private static string? DetermineDeviceType(HttpContext? context)
{
if (context == null) return "Unknown";
if (context is null)
{
return null;
}
return context.Request.Headers["User-Agent"].ToString().ToLower() switch
return context.Request.Headers.UserAgent.ToString().ToLower() switch
{
var ua when ua.Contains("mobile") || ua.Contains("android") || ua.Contains("iphone") => "Mobile",
var ua when ua.Contains("tablet") || ua.Contains("ipad") => "Tablet",
@@ -109,11 +112,11 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
/// </summary>
/// <param name="context">The HttpContext containing the request information.</param>
/// <returns>A string representing the operating system: "Windows", "MacOS", "Linux", "Android", "iOS", or "Unknown".</returns>
private string DetermineOperatingSystem(HttpContext? context)
private static string? DetermineOperatingSystem(HttpContext? context)
{
if (context is null)
{
return "Unknown";
return null;
}
return context.Request.Headers.UserAgent.ToString().ToLower() switch
@@ -123,7 +126,7 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
var ua when ua.Contains("linux") => "Linux",
var ua when ua.Contains("android") => "Android",
var ua when ua.Contains("iphone") || ua.Contains("ipad") => "iOS",
_ => "Unknown",
_ => null,
};
}
@@ -132,11 +135,11 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
/// </summary>
/// <param name="context">The HttpContext containing the request information.</param>
/// <returns>A string representing the browser: "Firefox", "Chrome", "Safari", "Edge", "Opera", or "Unknown".</returns>
private string DetermineBrowser(HttpContext? context)
private static string? DetermineBrowser(HttpContext? context)
{
if (context is null)
{
return "Unknown";
return null;
}
return context.Request.Headers.UserAgent.ToString().ToLower() switch
@@ -146,38 +149,29 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
var ua when ua.Contains("safari") && !ua.Contains("chrome") => "Safari",
var ua when ua.Contains("edg") => "Edge",
var ua when ua.Contains("opr") || ua.Contains("opera") => "Opera",
_ => "Unknown"
_ => null
};
}
/// <summary>
/// Determines the country based on the IP address of the request.
/// </summary>
/// <param name="context">The HttpContext containing the request information.</param>
/// <returns>A string representing the country or "Unknown" if the country cannot be determined.</returns>
/// <remarks>
/// This method currently returns a placeholder value. In a production environment,
/// it should be implemented using a Geo-IP database or service for accurate results.
/// This method currently returns null as the implementation is not yet complete.
/// </remarks>
private string DetermineCountry(HttpContext? context)
private static string? DetermineCountry()
{
// Implement later by using a Geo-IP database or service.
return "Unknown";
return null;
}
/// <summary>
/// Logic to determine if the activity is suspicious. For the moment it always returns false.
/// Needs to be implemented later.
/// </summary>
/// <returns></returns>
private bool DetermineSuspiciousActivity() => false;
/// <summary>
/// Extract IP address from HttpContext.
/// </summary>
/// <param name="httpContext">HttpContext to extract the IP address from.</param>
/// <returns></returns>
private string GetIpFromContext(HttpContext? httpContext)
private static string GetIpFromContext(HttpContext? httpContext)
{
string ipAddress = "";
@@ -189,13 +183,12 @@ public class AuthLoggingService(IServiceProvider serviceProvider, IHttpContextAc
if (string.IsNullOrEmpty(ipAddress))
{
// Check if X-Forwarded-For header exists, if so, extract first IP address from comma separated list.
if (httpContext.Request.Headers.ContainsKey("X-Forwarded-For"))
if (httpContext.Request.Headers.TryGetValue("X-Forwarded-For", out var xForwardedFor))
{
ipAddress = httpContext.Request.Headers["X-Forwarded-For"].ToString().Split(',')[0];
ipAddress = xForwardedFor.ToString().Split(',')[0];
}
else
{
// Otherwise use RemoteIpAddress.
ipAddress = httpContext.Connection.RemoteIpAddress?.ToString() ?? "0.0.0.0";
}
}