diff --git a/apps/server/AliasVault.Api/Controllers/Security/TwoFactorAuthController.cs b/apps/server/AliasVault.Api/Controllers/Security/TwoFactorAuthController.cs index 3bae548d7..2ccf38509 100644 --- a/apps/server/AliasVault.Api/Controllers/Security/TwoFactorAuthController.cs +++ b/apps/server/AliasVault.Api/Controllers/Security/TwoFactorAuthController.cs @@ -59,11 +59,22 @@ public class TwoFactorAuthController(IDbContextFactory dbC return Unauthorized(); } - var authenticatorKey = await GetUserManager().GetAuthenticatorKeyAsync(user); + string? authenticatorKey; + authenticatorKey = await GetUserManager().GetAuthenticatorKeyAsync(user); + + // Only reset (create new keys) if no key exists yet, avoiding duplicate key errors. if (string.IsNullOrEmpty(authenticatorKey)) { - await GetUserManager().ResetAuthenticatorKeyAsync(user); - authenticatorKey = await GetUserManager().GetAuthenticatorKeyAsync(user); + try + { + await GetUserManager().ResetAuthenticatorKeyAsync(user); + authenticatorKey = await GetUserManager().GetAuthenticatorKeyAsync(user); + } + catch (DbUpdateException) + { + // Key was most likely created by concurrent request, just get it. + authenticatorKey = await GetUserManager().GetAuthenticatorKeyAsync(user); + } } var encodedKey = urlEncoder.Encode(authenticatorKey!); @@ -90,14 +101,23 @@ public class TwoFactorAuthController(IDbContextFactory dbC if (isValid) { - await GetUserManager().SetTwoFactorEnabledAsync(user, true); + try + { + await GetUserManager().SetTwoFactorEnabledAsync(user, true); - // Generate new recovery codes. - var recoveryCodes = await GetUserManager().GenerateNewTwoFactorRecoveryCodesAsync(user, 10); + // Generate new recovery codes. + var recoveryCodes = await GetUserManager().GenerateNewTwoFactorRecoveryCodesAsync(user, 10); - await authLoggingService.LogAuthEventSuccessAsync(user.UserName!, AuthEventType.TwoFactorAuthEnable); + await authLoggingService.LogAuthEventSuccessAsync(user.UserName!, AuthEventType.TwoFactorAuthEnable); - return Ok(new { RecoveryCodes = recoveryCodes }); + return Ok(new { RecoveryCodes = recoveryCodes }); + } + catch (DbUpdateException) + { + // Likely a concurrent request already enabled 2FA, still return success. + var recoveryCodes = await GetUserManager().GenerateNewTwoFactorRecoveryCodesAsync(user, 10); + return Ok(new { RecoveryCodes = recoveryCodes }); + } } return BadRequest("Invalid code."); @@ -117,17 +137,28 @@ public class TwoFactorAuthController(IDbContextFactory dbC } await using var context = await dbContextFactory.CreateDbContextAsync(); + await using var transaction = await context.Database.BeginTransactionAsync(); - // Disable 2FA and remove any existing authenticator key(s) and recovery codes. - await GetUserManager().SetTwoFactorEnabledAsync(user, false); - context.UserTokens.RemoveRange( - await context.UserTokens.Where( - x => x.UserId == user.Id && - (x.Name == "AuthenticatorKey" || x.Name == "RecoveryCodes")).ToListAsync()); + try + { + // Disable 2FA and remove any existing authenticator key(s) and recovery codes. + await GetUserManager().SetTwoFactorEnabledAsync(user, false); - await context.SaveChangesAsync(); + context.UserTokens.RemoveRange( + await context.UserTokens.Where( + x => x.UserId == user.Id && + (x.Name == "AuthenticatorKey" || x.Name == "RecoveryCodes")).ToListAsync()); - await authLoggingService.LogAuthEventSuccessAsync(user.UserName!, AuthEventType.TwoFactorAuthDisable); - return Ok(); + await context.SaveChangesAsync(); + await transaction.CommitAsync(); + + await authLoggingService.LogAuthEventSuccessAsync(user.UserName!, AuthEventType.TwoFactorAuthDisable); + return Ok(); + } + catch + { + await transaction.RollbackAsync(); + throw; + } } }