//-----------------------------------------------------------------------
//
// Copyright (c) lanedirt. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
//-----------------------------------------------------------------------
namespace AliasVault.Api.Controllers;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Security.Cryptography;
using System.Text;
using AliasDb;
using AliasVault.Shared.Models;
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc;
using Microsoft.IdentityModel.Tokens;
/// <summary>
/// Auth controller for handling authentication: login, registration, and JWT access/refresh token management.
/// </summary>
/// <param name="context">AliasDbContext instance.</param>
/// <param name="userManager">UserManager instance.</param>
/// <param name="signInManager">SignInManager instance.</param>
/// <param name="configuration">IConfiguration instance.</param>
[Route("api/[controller]")]
[ApiController]
public class AuthController(AliasDbContext context, UserManager<IdentityUser> userManager, SignInManager<IdentityUser> signInManager, IConfiguration configuration) : ControllerBase
{
    /// <summary>
    /// Lifetime of an issued JWT access token, in minutes.
    /// </summary>
    private const int AccessTokenLifetimeMinutes = 30;

    /// <summary>
    /// Lifetime of an issued refresh token, in days.
    /// </summary>
    private const int RefreshTokenLifetimeDays = 30;

    /// <summary>
    /// Number of cryptographically random bytes in a refresh token (before Base64 encoding).
    /// </summary>
    private const int RefreshTokenByteLength = 32;

    /// <summary>
    /// Login endpoint used to process login attempt using credentials.
    /// </summary>
    /// <param name="model">Login model.</param>
    /// <returns>200 with a new access/refresh token pair, or 401 when the credentials are invalid.</returns>
    [HttpPost("login")]
    public async Task<IActionResult> Login([FromBody] LoginModel model)
    {
        var user = await userManager.FindByEmailAsync(model.Email);
        if (user is null || !await userManager.CheckPasswordAsync(user, model.Password))
        {
            return Unauthorized();
        }

        var tokenModel = await GenerateNewTokenForUserAsync(user);
        return Ok(tokenModel);
    }

    /// <summary>
    /// Refresh endpoint used to refresh an expired access token using a valid refresh token.
    /// </summary>
    /// <param name="tokenModel">Token model containing the (possibly expired) access token and the refresh token.</param>
    /// <returns>200 with a new access/refresh token pair, or 401 when the tokens are invalid or expired.</returns>
    [HttpPost("refresh")]
    public async Task<IActionResult> Refresh([FromBody] TokenModel tokenModel)
    {
        var (user, error) = await ResolveUserFromTokenAsync(tokenModel.Token);
        if (user is null)
        {
            return error ?? Unauthorized();
        }

        // Refresh tokens are bound to both the user and the requesting device.
        var deviceIdentifier = GenerateDeviceIdentifier(Request);
        var existingToken = context.AspNetUserRefreshTokens.FirstOrDefault(t => t.UserId == user.Id && t.DeviceIdentifier == deviceIdentifier);
        if (existingToken is null
            || !RefreshTokenMatches(existingToken.Value, tokenModel.RefreshToken)
            || existingToken.ExpireDate < DateTime.UtcNow)
        {
            return Unauthorized("Refresh token expired");
        }

        // Rotate: every stored refresh token for this user/device (including the one just presented)
        // is replaced by a fresh one, and a new access token is issued.
        var newTokenModel = await GenerateNewTokenForUserAsync(user);
        return Ok(newTokenModel);
    }

    /// <summary>
    /// Revoke endpoint used to revoke a refresh token.
    /// </summary>
    /// <param name="model">Token model containing the (possibly expired) access token and the refresh token to revoke.</param>
    /// <returns>200 when the refresh token was revoked, or 401 when the tokens are invalid.</returns>
    [HttpPost("revoke")]
    public async Task<IActionResult> Revoke([FromBody] TokenModel model)
    {
        var (user, error) = await ResolveUserFromTokenAsync(model.Token);
        if (user is null)
        {
            return error ?? Unauthorized();
        }

        var deviceIdentifier = GenerateDeviceIdentifier(Request);
        var existingToken = context.AspNetUserRefreshTokens.FirstOrDefault(t => t.UserId == user.Id && t.DeviceIdentifier == deviceIdentifier);
        if (existingToken is null || !RefreshTokenMatches(existingToken.Value, model.RefreshToken))
        {
            return Unauthorized("Invalid refresh token");
        }

        context.AspNetUserRefreshTokens.Remove(existingToken);
        await context.SaveChangesAsync();
        return Ok("Refresh token revoked successfully");
    }

    /// <summary>
    /// Register endpoint used to register a new user.
    /// </summary>
    /// <param name="model">Register model.</param>
    /// <returns>200 with a new access/refresh token pair, or 400 with the identity errors.</returns>
    [HttpPost("register")]
    public async Task<IActionResult> Register([FromBody] RegisterModel model)
    {
        var user = new IdentityUser { UserName = model.Email, Email = model.Email };
        var result = await userManager.CreateAsync(user, model.Password);
        if (!result.Succeeded)
        {
            return BadRequest(result.Errors);
        }

        // When a user is registered, they are automatically signed in.
        await signInManager.SignInAsync(user, isPersistent: false);

        var tokenModel = await GenerateNewTokenForUserAsync(user);
        return Ok(tokenModel);
    }

    /// <summary>
    /// Compares a stored refresh token with the one presented by the client in constant time,
    /// so response timing does not leak how many leading characters matched.
    /// </summary>
    /// <param name="storedToken">Refresh token value persisted in the database.</param>
    /// <param name="presentedToken">Refresh token value supplied by the client.</param>
    /// <returns>True when both values are non-null and identical.</returns>
    private static bool RefreshTokenMatches(string? storedToken, string? presentedToken)
    {
        if (storedToken is null || presentedToken is null)
        {
            return false;
        }

        return CryptographicOperations.FixedTimeEquals(
            Encoding.UTF8.GetBytes(storedToken),
            Encoding.UTF8.GetBytes(presentedToken));
    }

    /// <summary>
    /// Generates a new cryptographically random, Base64-encoded refresh token.
    /// </summary>
    /// <returns>The refresh token string.</returns>
    private static string GenerateRefreshToken()
    {
        return Convert.ToBase64String(RandomNumberGenerator.GetBytes(RefreshTokenByteLength));
    }

    /// <summary>
    /// Builds the identifier used to bind a refresh token to the requesting device.
    /// </summary>
    /// <param name="request">The current HTTP request.</param>
    /// <returns>The User-Agent and Accept-Language headers joined with '|'.</returns>
    private static string GenerateDeviceIdentifier(HttpRequest request)
    {
        // TODO: Add more headers to the device identifier or let client send a unique identifier instead.
        var userAgent = request.Headers.UserAgent.ToString();
        var acceptLanguage = request.Headers.AcceptLanguage.ToString();
        return $"{userAgent}|{acceptLanguage}";
    }

    /// <summary>
    /// Creates a signed JWT access token for the given user.
    /// </summary>
    /// <param name="user">The user to issue the token for.</param>
    /// <returns>The serialized JWT.</returns>
    private string GenerateJwtToken(IdentityUser user)
    {
        var claims = new List<Claim>
        {
            new(ClaimTypes.NameIdentifier, user.Id ?? string.Empty),
            new(ClaimTypes.Name, user.UserName ?? string.Empty),
            new(ClaimTypes.Email, user.Email ?? string.Empty),
            new(JwtRegisteredClaimNames.Jti, Guid.NewGuid().ToString()),
        };

        var creds = new SigningCredentials(GetSigningKey(), SecurityAlgorithms.HmacSha256);

        // The configured issuer is also used as the audience.
        var issuer = configuration["Jwt:Issuer"] ?? string.Empty;
        var token = new JwtSecurityToken(
            issuer: issuer,
            audience: issuer,
            claims: claims,
            expires: DateTime.UtcNow.AddMinutes(AccessTokenLifetimeMinutes),
            signingCredentials: creds);

        return new JwtSecurityTokenHandler().WriteToken(token);
    }

    /// <summary>
    /// Builds the symmetric key used to sign and validate access tokens from the "Jwt:Key" setting.
    /// </summary>
    /// <returns>The signing key.</returns>
    /// <exception cref="InvalidOperationException">Thrown when "Jwt:Key" is not configured.</exception>
    private SymmetricSecurityKey GetSigningKey()
    {
        var key = configuration["Jwt:Key"];
        if (string.IsNullOrEmpty(key))
        {
            throw new InvalidOperationException("Jwt:Key is not configured.");
        }

        return new SymmetricSecurityKey(Encoding.UTF8.GetBytes(key));
    }

    /// <summary>
    /// Validates the signature of an access token while deliberately ignoring its lifetime,
    /// so an expired token can still identify the user during refresh/revoke.
    /// </summary>
    /// <param name="token">The serialized JWT access token.</param>
    /// <returns>The token's principal, or null when the token is empty, malformed, not signed with our key, or not HS256.</returns>
    private ClaimsPrincipal? GetPrincipalFromExpiredToken(string? token)
    {
        if (string.IsNullOrWhiteSpace(token))
        {
            return null;
        }

        var tokenValidationParameters = new TokenValidationParameters
        {
            ValidateAudience = false,
            ValidateIssuer = false,
            ValidateIssuerSigningKey = true,
            IssuerSigningKey = GetSigningKey(),
            ValidateLifetime = false,
        };

        try
        {
            var principal = new JwtSecurityTokenHandler().ValidateToken(token, tokenValidationParameters, out var securityToken);
            if (securityToken is JwtSecurityToken jwtSecurityToken
                && string.Equals(jwtSecurityToken.Header.Alg, SecurityAlgorithms.HmacSha256, StringComparison.OrdinalIgnoreCase))
            {
                return principal;
            }
        }
        catch (Exception ex) when (ex is SecurityTokenException or ArgumentException)
        {
            // Malformed, tampered, or wrongly-signed tokens are client errors: report them as
            // unauthenticated (null) instead of letting the exception surface as a 500.
        }

        return null;
    }

    /// <summary>
    /// Resolves the user referenced by a (possibly expired) access token.
    /// </summary>
    /// <param name="token">The serialized JWT access token.</param>
    /// <returns>The user on success; otherwise a null user and the 401 result to return to the client.</returns>
    private async Task<(IdentityUser? User, IActionResult? Error)> ResolveUserFromTokenAsync(string? token)
    {
        var principal = GetPrincipalFromExpiredToken(token);
        if (principal is null)
        {
            return (null, Unauthorized("Invalid token"));
        }

        var userId = principal.FindFirst(ClaimTypes.NameIdentifier)?.Value;
        if (userId is null)
        {
            return (null, Unauthorized("User not found (email-1)"));
        }

        var user = await userManager.FindByIdAsync(userId);
        if (user is null)
        {
            return (null, Unauthorized("User not found (email-2)"));
        }

        return (user, null);
    }

    /// <summary>
    /// Issues a new access token and refresh token for the user. The refresh token is persisted,
    /// replacing any refresh tokens previously stored for the same user and device.
    /// </summary>
    /// <param name="user">The user to issue tokens for.</param>
    /// <returns>The new access/refresh token pair.</returns>
    private async Task<TokenModel> GenerateNewTokenForUserAsync(IdentityUser user)
    {
        var token = GenerateJwtToken(user);
        var refreshToken = GenerateRefreshToken();
        var deviceIdentifier = GenerateDeviceIdentifier(Request);

        // Keep at most one refresh token per user/device.
        var existingTokens = context.AspNetUserRefreshTokens.Where(t => t.UserId == user.Id && t.DeviceIdentifier == deviceIdentifier);
        context.AspNetUserRefreshTokens.RemoveRange(existingTokens);

        var now = DateTime.UtcNow;
        await context.AspNetUserRefreshTokens.AddAsync(new AspNetUserRefreshToken
        {
            UserId = user.Id,
            DeviceIdentifier = deviceIdentifier,
            Value = refreshToken,
            ExpireDate = now.AddDays(RefreshTokenLifetimeDays),
            CreatedAt = now,
        });
        await context.SaveChangesAsync();

        return new TokenModel() { Token = token, RefreshToken = refreshToken };
    }
}