//-----------------------------------------------------------------------
// <copyright file="AuthController.cs" company="lanedirt">
// Copyright (c) lanedirt. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
// </copyright>
//-----------------------------------------------------------------------

namespace AliasVault.Api.Controllers;

using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Security.Cryptography;
using System.Text;
using AliasDb;
using AliasVault.Shared.Models;
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc;
using Microsoft.IdentityModel.Tokens;

/// <summary>
/// Auth controller for handling authentication: login, registration, and issuing,
/// refreshing and revoking JWT access tokens together with their per-device refresh tokens.
/// </summary>
/// <param name="context">AliasDbContext instance.</param>
/// <param name="userManager">UserManager instance.</param>
/// <param name="signInManager">SignInManager instance.</param>
/// <param name="configuration">IConfiguration instance.</param>
[Route("api/[controller]")]
[ApiController]
public class AuthController(AliasDbContext context, UserManager<IdentityUser> userManager, SignInManager<IdentityUser> signInManager, IConfiguration configuration) : ControllerBase
{
    /// <summary>
    /// Number of random bytes in a refresh token (before Base64 encoding).
    /// </summary>
    private const int RefreshTokenByteLength = 32;

    /// <summary>
    /// Lifetime of an issued JWT access token.
    /// </summary>
    private static readonly TimeSpan AccessTokenLifetime = TimeSpan.FromMinutes(30);

    /// <summary>
    /// Lifetime of a stored refresh token.
    /// </summary>
    private static readonly TimeSpan RefreshTokenLifetime = TimeSpan.FromDays(30);

    /// <summary>
    /// Login endpoint used to process login attempt using credentials.
    /// </summary>
    /// <param name="model">Login model.</param>
    /// <returns>IActionResult: 200 with a new <see cref="TokenModel"/> on success, 401 otherwise.</returns>
    [HttpPost("login")]
    public async Task<IActionResult> Login([FromBody] LoginModel model)
    {
        var user = await userManager.FindByEmailAsync(model.Email);
        if (user is null || !await userManager.CheckPasswordAsync(user, model.Password))
        {
            // Unknown e-mail and wrong password both produce the same bare 401.
            return Unauthorized();
        }

        var tokenModel = await GenerateNewTokenForUser(user);
        return Ok(tokenModel);
    }

    /// <summary>
    /// Refresh endpoint used to refresh an expired access token using a valid refresh token.
    /// The presented refresh token is consumed and replaced by a new one (rotation).
    /// </summary>
    /// <param name="tokenModel">Token model holding the (possibly expired) access token and the refresh token.</param>
    /// <returns>IActionResult: 200 with a new <see cref="TokenModel"/>, or 401 when the access token, user or refresh token is invalid.</returns>
    [HttpPost("refresh")]
    public async Task<IActionResult> Refresh([FromBody] TokenModel tokenModel)
    {
        var (user, failure) = await ResolveUserFromExpiredTokenAsync(tokenModel.Token);
        if (user is null)
        {
            return failure ?? Unauthorized();
        }

        // The refresh token must exist for this user/device, match the presented value and not be expired.
        var deviceIdentifier = GenerateDeviceIdentifier(Request);
        var existingToken = FindRefreshToken(user.Id, deviceIdentifier);
        if (existingToken is null
            || !RefreshTokenMatches(existingToken.Value, tokenModel.RefreshToken)
            || existingToken.ExpireDate < DateTime.Now)
        {
            return Unauthorized("Refresh token expired");
        }

        // Rotate: remove the consumed refresh token and store its replacement in the same SaveChanges call.
        context.AspNetUserRefreshTokens.Remove(existingToken);
        var newRefreshToken = GenerateRefreshToken();
        await context.AspNetUserRefreshTokens.AddAsync(CreateRefreshTokenEntity(user.Id, deviceIdentifier, newRefreshToken));
        await context.SaveChangesAsync();

        var token = GenerateJwtToken(user);
        return Ok(new TokenModel() { Token = token, RefreshToken = newRefreshToken });
    }

    /// <summary>
    /// Revoke endpoint used to revoke a refresh token.
    /// </summary>
    /// <param name="model">Token model holding the (possibly expired) access token and the refresh token to revoke.</param>
    /// <returns>IActionResult: 200 when the refresh token was removed, 401 when the access token, user or refresh token is invalid.</returns>
    [HttpPost("revoke")]
    public async Task<IActionResult> Revoke([FromBody] TokenModel model)
    {
        var (user, failure) = await ResolveUserFromExpiredTokenAsync(model.Token);
        if (user is null)
        {
            return failure ?? Unauthorized();
        }

        var deviceIdentifier = GenerateDeviceIdentifier(Request);
        var existingToken = FindRefreshToken(user.Id, deviceIdentifier);
        if (existingToken is null || !RefreshTokenMatches(existingToken.Value, model.RefreshToken))
        {
            return Unauthorized("Invalid refresh token");
        }

        context.AspNetUserRefreshTokens.Remove(existingToken);
        await context.SaveChangesAsync();

        return Ok("Refresh token revoked successfully");
    }

    /// <summary>
    /// Register endpoint used to register a new user.
    /// </summary>
    /// <param name="model">Register model.</param>
    /// <returns>IActionResult: 200 with a new <see cref="TokenModel"/> on success, 400 with the Identity errors otherwise.</returns>
    [HttpPost("register")]
    public async Task<IActionResult> Register([FromBody] RegisterModel model)
    {
        var user = new IdentityUser { UserName = model.Email, Email = model.Email };
        var result = await userManager.CreateAsync(user, model.Password);
        if (!result.Succeeded)
        {
            return BadRequest(result.Errors);
        }

        // When a user is registered, they are automatically signed in.
        await signInManager.SignInAsync(user, isPersistent: false);

        var tokenModel = await GenerateNewTokenForUser(user);
        return Ok(tokenModel);
    }

    /// <summary>
    /// Compares a stored refresh token with the presented one in constant time.
    /// </summary>
    /// <param name="stored">Refresh token value from the database.</param>
    /// <param name="presented">Refresh token value sent by the client.</param>
    /// <returns>True when both are non-null and equal.</returns>
    private static bool RefreshTokenMatches(string? stored, string? presented)
    {
        if (stored is null || presented is null)
        {
            return false;
        }

        return CryptographicOperations.FixedTimeEquals(Encoding.UTF8.GetBytes(stored), Encoding.UTF8.GetBytes(presented));
    }

    /// <summary>
    /// Generates a cryptographically random, Base64-encoded refresh token.
    /// </summary>
    /// <returns>The refresh token string.</returns>
    private static string GenerateRefreshToken()
    {
        return Convert.ToBase64String(RandomNumberGenerator.GetBytes(RefreshTokenByteLength));
    }

    /// <summary>
    /// Builds the identifier used to scope refresh tokens to a client device.
    /// </summary>
    /// <param name="request">Current HTTP request.</param>
    /// <returns>"{User-Agent}|{Accept-Language}".</returns>
    private static string GenerateDeviceIdentifier(HttpRequest request)
    {
        // TODO: Add more headers to the device identifier or let client send a unique identifier instead.
        var userAgent = request.Headers.UserAgent.ToString();
        var acceptLanguage = request.Headers.AcceptLanguage.ToString();
        return $"{userAgent}|{acceptLanguage}";
    }

    /// <summary>
    /// Creates a new refresh token row for the given user and device.
    /// </summary>
    /// <param name="userId">Identity user id.</param>
    /// <param name="deviceIdentifier">Device identifier from <see cref="GenerateDeviceIdentifier"/>.</param>
    /// <param name="value">Refresh token value.</param>
    /// <returns>The unsaved entity.</returns>
    private static AspNetUserRefreshToken CreateRefreshTokenEntity(string userId, string deviceIdentifier, string value)
    {
        // NOTE(review): refresh-token timestamps are stored and compared in server local time (DateTime.Now).
        // Switching to UtcNow would misinterpret rows already persisted in local time, so it is kept as-is.
        var now = DateTime.Now;
        return new AspNetUserRefreshToken
        {
            UserId = userId,
            DeviceIdentifier = deviceIdentifier,
            Value = value,
            ExpireDate = now.Add(RefreshTokenLifetime),
            CreatedAt = now,
        };
    }

    /// <summary>
    /// Resolves the user referenced by a (possibly expired) access token.
    /// </summary>
    /// <param name="token">JWT access token sent by the client.</param>
    /// <returns>The user, or null together with the 401 result to return to the client.</returns>
    private async Task<(IdentityUser? User, IActionResult? Failure)> ResolveUserFromExpiredTokenAsync(string? token)
    {
        var principal = GetPrincipalFromExpiredToken(token);
        if (principal is null)
        {
            return (null, Unauthorized("Invalid token"));
        }

        var userId = principal.FindFirst(ClaimTypes.NameIdentifier)?.Value;
        if (userId is null)
        {
            return (null, Unauthorized("User not found (email-1)"));
        }

        var user = await userManager.FindByIdAsync(userId);
        if (user is null)
        {
            return (null, Unauthorized("User not found (email-2)"));
        }

        return (user, null);
    }

    /// <summary>
    /// Looks up the stored refresh token for a user/device pair.
    /// </summary>
    /// <param name="userId">Identity user id.</param>
    /// <param name="deviceIdentifier">Device identifier from <see cref="GenerateDeviceIdentifier"/>.</param>
    /// <returns>The first matching refresh token, or null.</returns>
    private AspNetUserRefreshToken? FindRefreshToken(string userId, string deviceIdentifier)
    {
        return context.AspNetUserRefreshTokens
            .Where(t => t.UserId == userId && t.DeviceIdentifier == deviceIdentifier)
            .FirstOrDefault();
    }

    /// <summary>
    /// Creates the symmetric key used to sign and validate access tokens from the "Jwt:Key" setting.
    /// </summary>
    /// <returns>The signing key.</returns>
    /// <exception cref="InvalidOperationException">Thrown when "Jwt:Key" is missing or empty.</exception>
    private SymmetricSecurityKey GetSigningKey()
    {
        var key = configuration["Jwt:Key"];
        if (string.IsNullOrEmpty(key))
        {
            throw new InvalidOperationException("JWT signing key is not configured (Jwt:Key).");
        }

        return new SymmetricSecurityKey(Encoding.UTF8.GetBytes(key));
    }

    /// <summary>
    /// Generates a signed HS256 access token for the given user.
    /// </summary>
    /// <param name="user">User to issue the token for.</param>
    /// <returns>The serialized JWT.</returns>
    private string GenerateJwtToken(IdentityUser user)
    {
        List<Claim> claims =
        [
            new(ClaimTypes.NameIdentifier, user.Id ?? string.Empty),
            new(ClaimTypes.Name, user.UserName ?? string.Empty),
            new(ClaimTypes.Email, user.Email ?? string.Empty),
            new(JwtRegisteredClaimNames.Jti, Guid.NewGuid().ToString()),
        ];

        var creds = new SigningCredentials(GetSigningKey(), SecurityAlgorithms.HmacSha256);

        // NOTE(review): the audience is taken from "Jwt:Issuer"; confirm this matches what the API's bearer validation expects.
        var token = new JwtSecurityToken(
            issuer: configuration["Jwt:Issuer"] ?? string.Empty,
            audience: configuration["Jwt:Issuer"] ?? string.Empty,
            claims: claims,
            expires: DateTime.UtcNow.Add(AccessTokenLifetime),
            signingCredentials: creds);

        return new JwtSecurityTokenHandler().WriteToken(token);
    }

    /// <summary>
    /// Validates the signature of an access token while ignoring its lifetime, so expired tokens
    /// can still identify the user during refresh/revoke.
    /// </summary>
    /// <param name="token">JWT access token.</param>
    /// <returns>The token's principal, or null when the token is missing, malformed, badly signed or not HS256.</returns>
    private ClaimsPrincipal? GetPrincipalFromExpiredToken(string? token)
    {
        if (string.IsNullOrWhiteSpace(token))
        {
            return null;
        }

        // Built outside the try block so a missing signing key surfaces as a configuration error, not as a 401.
        var tokenValidationParameters = new TokenValidationParameters
        {
            ValidateAudience = false,
            ValidateIssuer = false,
            ValidateIssuerSigningKey = true,
            IssuerSigningKey = GetSigningKey(),
            ValidateLifetime = false,
        };

        try
        {
            var principal = new JwtSecurityTokenHandler().ValidateToken(token, tokenValidationParameters, out var securityToken);

            // Only accept tokens signed with the algorithm this controller issues.
            if (securityToken is not JwtSecurityToken jwtSecurityToken
                || !string.Equals(jwtSecurityToken.Header.Alg, SecurityAlgorithms.HmacSha256, StringComparison.OrdinalIgnoreCase))
            {
                return null;
            }

            return principal;
        }
        catch (Exception ex) when (ex is SecurityTokenException or ArgumentException)
        {
            // Malformed or tampered tokens are a client error; callers map null to 401.
            return null;
        }
    }

    /// <summary>
    /// Issues a new access token and refresh token for the user, replacing any refresh tokens
    /// previously stored for the same user and device.
    /// </summary>
    /// <param name="user">User to issue tokens for.</param>
    /// <returns>The new token pair.</returns>
    private async Task<TokenModel> GenerateNewTokenForUser(IdentityUser user)
    {
        var token = GenerateJwtToken(user);
        var refreshToken = GenerateRefreshToken();
        var deviceIdentifier = GenerateDeviceIdentifier(Request);

        // Remove any existing refresh tokens for this user and device before saving the new one.
        var existingTokens = context.AspNetUserRefreshTokens
            .Where(t => t.UserId == user.Id && t.DeviceIdentifier == deviceIdentifier);
        context.AspNetUserRefreshTokens.RemoveRange(existingTokens);

        await context.AspNetUserRefreshTokens.AddAsync(CreateRefreshTokenEntity(user.Id, deviceIdentifier, refreshToken));
        await context.SaveChangesAsync();

        return new TokenModel() { Token = token, RefreshToken = refreshToken };
    }
}