mirror of
https://github.com/kopia/kopia.git
synced 2025-12-23 22:57:50 -05:00
feat(cli): add server user set-password-hash command (#3974)
Objectives: - Facilitate the generation of valid password hashes that can be used with the `server user --user-password` CLI command. - Encapsulate implementation details of password hashing in the `user` package. Adds a new `server user hash-password` CLI command to generate the hash from a supplied password. Modifies the `server user set/add --user-password-hash` CLI command to accept the password hash generated using the `hash-password` command. Adds `GetNewProfile(ctx, rep, username)` helper to move implementation details to the `user` package. Includes CLI and unit tests. Cleans up and removes unused functions.
This commit is contained in:
@@ -4,6 +4,7 @@ type commandServerUser struct {
|
||||
add commandServerUserAddSet
|
||||
set commandServerUserAddSet
|
||||
delete commandServerUserDelete
|
||||
hash commandServerUserHashPassword
|
||||
info commandServerUserInfo
|
||||
list commandServerUserList
|
||||
}
|
||||
@@ -14,6 +15,7 @@ func (c *commandServerUser) setup(svc appServices, parent commandParent) {
|
||||
c.add.setup(svc, cmd, true)
|
||||
c.set.setup(svc, cmd, false)
|
||||
c.delete.setup(svc, cmd)
|
||||
c.hash.setup(svc, cmd)
|
||||
c.info.setup(svc, cmd)
|
||||
c.list.setup(svc, cmd)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"github.com/pkg/errors"
|
||||
@@ -12,11 +12,10 @@
|
||||
)
|
||||
|
||||
type commandServerUserAddSet struct {
|
||||
userAskPassword bool
|
||||
userSetName string
|
||||
userSetPassword string
|
||||
userSetPasswordHashAlgorithm string
|
||||
userSetPasswordHash string
|
||||
userAskPassword bool
|
||||
userSetName string
|
||||
userSetPassword string
|
||||
userSetPasswordHash string
|
||||
|
||||
isNew bool // true == 'add', false == 'update'
|
||||
out textOutput
|
||||
@@ -36,7 +35,6 @@ func (c *commandServerUserAddSet) setup(svc appServices, parent commandParent, i
|
||||
cmd.Flag("ask-password", "Ask for user password").BoolVar(&c.userAskPassword)
|
||||
cmd.Flag("user-password", "Password").StringVar(&c.userSetPassword)
|
||||
cmd.Flag("user-password-hash", "Password hash").StringVar(&c.userSetPasswordHash)
|
||||
cmd.Flag("user-password-hashing-algorithm", "[Experimental] Password hashing algorithm").Hidden().Default(user.DefaultPasswordHashingAlgorithm).EnumVar(&c.userSetPasswordHashAlgorithm, user.PasswordHashingAlgorithms()...)
|
||||
cmd.Arg("username", "Username").Required().StringVar(&c.userSetName)
|
||||
cmd.Action(svc.repositoryWriterAction(c.runServerUserAddSet))
|
||||
|
||||
@@ -44,26 +42,14 @@ func (c *commandServerUserAddSet) setup(svc appServices, parent commandParent, i
|
||||
}
|
||||
|
||||
func (c *commandServerUserAddSet) getExistingOrNewUserProfile(ctx context.Context, rep repo.Repository, username string) (*user.Profile, error) {
|
||||
up, err := user.GetUserProfile(ctx, rep, username)
|
||||
|
||||
if c.isNew {
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil, errors.Errorf("user %q already exists", username)
|
||||
up, err := user.GetNewProfile(ctx, rep, username)
|
||||
|
||||
case errors.Is(err, user.ErrUserNotFound):
|
||||
passwordHashVersion, err := user.GetPasswordHashVersion(c.userSetPasswordHashAlgorithm)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get password hash version")
|
||||
}
|
||||
|
||||
return &user.Profile{
|
||||
Username: username,
|
||||
PasswordHashVersion: passwordHashVersion,
|
||||
}, nil
|
||||
}
|
||||
return up, errors.Wrap(err, "error getting new user profile")
|
||||
}
|
||||
|
||||
up, err := user.GetUserProfile(ctx, rep, username)
|
||||
|
||||
return up, errors.Wrap(err, "error getting user profile")
|
||||
}
|
||||
|
||||
@@ -85,29 +71,18 @@ func (c *commandServerUserAddSet) runServerUserAddSet(ctx context.Context, rep r
|
||||
}
|
||||
}
|
||||
|
||||
if p := c.userSetPasswordHash; p != "" {
|
||||
ph, err := base64.StdEncoding.DecodeString(p)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "invalid password hash, must be valid base64 string")
|
||||
if ph := c.userSetPasswordHash; ph != "" {
|
||||
if err := up.SetPasswordHash(ph); err != nil {
|
||||
return errors.Wrap(err, "error setting password hash")
|
||||
}
|
||||
|
||||
up.PasswordHash = ph
|
||||
changed = true
|
||||
}
|
||||
|
||||
if up.PasswordHash == nil || c.userAskPassword {
|
||||
pwd, err := askPass(c.out.stdout(), "Enter new password for user "+username+": ")
|
||||
pwd, err := askConfirmPass(c.out.stdout(), "Enter new password for user "+username+": ")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error asking for password")
|
||||
}
|
||||
|
||||
pwd2, err := askPass(c.out.stdout(), "Re-enter new password for verification: ")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error asking for password")
|
||||
}
|
||||
|
||||
if pwd != pwd2 {
|
||||
return errors.Wrap(err, "passwords don't match")
|
||||
return err
|
||||
}
|
||||
|
||||
changed = true
|
||||
@@ -132,3 +107,21 @@ func (c *commandServerUserAddSet) runServerUserAddSet(ctx context.Context, rep r
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func askConfirmPass(out io.Writer, initialPrompt string) (string, error) {
|
||||
pwd, err := askPass(out, initialPrompt)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error asking for password")
|
||||
}
|
||||
|
||||
pwd2, err := askPass(out, "Re-enter password for verification: ")
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error asking for password")
|
||||
}
|
||||
|
||||
if pwd != pwd2 {
|
||||
return "", errors.Wrap(err, "passwords don't match")
|
||||
}
|
||||
|
||||
return pwd, nil
|
||||
}
|
||||
|
||||
52
cli/command_user_hash_password.go
Normal file
52
cli/command_user_hash_password.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/kopia/kopia/internal/user"
|
||||
"github.com/kopia/kopia/repo"
|
||||
)
|
||||
|
||||
type commandServerUserHashPassword struct {
|
||||
password string
|
||||
|
||||
out textOutput
|
||||
}
|
||||
|
||||
func (c *commandServerUserHashPassword) setup(svc appServices, parent commandParent) {
|
||||
cmd := parent.Command("hash-password", "Hash a user password that can be passed to the 'server user add/set' command").Alias("hash")
|
||||
|
||||
cmd.Flag("user-password", "Password").StringVar(&c.password)
|
||||
|
||||
cmd.Action(svc.repositoryWriterAction(c.runServerUserHashPassword))
|
||||
|
||||
c.out.setup(svc)
|
||||
}
|
||||
|
||||
// The current implementation does not require a connected repository, thus the
|
||||
// RepositoryWriter parameter is not used. Future implementations will need a
|
||||
// connected repository. To avoid a future incompatible change where the
|
||||
// 'hash-password' command stops working without a connected repository,
|
||||
// a connected repository is required now.
|
||||
func (c *commandServerUserHashPassword) runServerUserHashPassword(ctx context.Context, _ repo.RepositoryWriter) error {
|
||||
if c.password == "" {
|
||||
// when password hash is empty, ask for password
|
||||
pwd, err := askConfirmPass(c.out.stdout(), "Enter password to hash: ")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error getting password")
|
||||
}
|
||||
|
||||
c.password = pwd
|
||||
}
|
||||
|
||||
h, err := user.HashPassword(c.password)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "hashing password")
|
||||
}
|
||||
|
||||
c.out.printStdout("%s\n", h)
|
||||
|
||||
return nil
|
||||
}
|
||||
104
cli/command_user_hash_password_test.go
Normal file
104
cli/command_user_hash_password_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/kopia/kopia/internal/testutil"
|
||||
"github.com/kopia/kopia/tests/testenv"
|
||||
)
|
||||
|
||||
func TestServerUserHashPassword(t *testing.T) {
|
||||
const (
|
||||
userName = "user78"
|
||||
userHost = "client-host"
|
||||
userFull = userName + "@" + userHost
|
||||
)
|
||||
|
||||
runner := testenv.NewInProcRunner(t)
|
||||
e := testenv.NewCLITest(t, testenv.RepoFormatNotImportant, runner)
|
||||
|
||||
e.RunAndExpectSuccess(t, "repo", "create", "filesystem", "--path", e.RepoDir, "--override-username", "server", "--override-hostname", "host")
|
||||
|
||||
t.Cleanup(func() {
|
||||
e.RunAndExpectSuccess(t, "repo", "disconnect")
|
||||
})
|
||||
|
||||
userPassword := "bad-password-" + strconv.Itoa(int(rand.Int31()))
|
||||
|
||||
out := e.RunAndExpectSuccess(t, "server", "users", "hash-password", "--user-password", userPassword)
|
||||
|
||||
require.Len(t, out, 1)
|
||||
|
||||
passwordHash := out[0]
|
||||
require.NotEmpty(t, passwordHash)
|
||||
|
||||
// attempt to create a user with a bad password hash
|
||||
e.RunAndExpectFailure(t, "server", "users", "add", userFull, "--user-password-hash", "bad-base64")
|
||||
|
||||
// create a new user with and set the password using the password hash
|
||||
e.RunAndExpectSuccess(t, "server", "users", "add", userFull, "--user-password-hash", passwordHash)
|
||||
|
||||
// start server to test accessing the server with user created above
|
||||
var sp testutil.ServerParameters
|
||||
|
||||
wait, kill := e.RunAndProcessStderr(t, sp.ProcessOutput,
|
||||
"server", "start",
|
||||
"--address=localhost:0",
|
||||
"--tls-generate-cert",
|
||||
"--random-server-control-password",
|
||||
"--shutdown-grace-period", "100ms",
|
||||
)
|
||||
|
||||
t.Cleanup(func() {
|
||||
kill()
|
||||
wait()
|
||||
t.Log("server stopped")
|
||||
})
|
||||
|
||||
t.Logf("detected server parameters %#v", sp)
|
||||
|
||||
// connect to the server repo using a client with the user created above
|
||||
cr := testenv.NewInProcRunner(t)
|
||||
clientEnv := testenv.NewCLITest(t, testenv.RepoFormatNotImportant, cr)
|
||||
|
||||
delete(clientEnv.Environment, "KOPIA_PASSWORD")
|
||||
|
||||
clientEnv.RunAndExpectSuccess(t, "repo", "connect", "server",
|
||||
"--url", sp.BaseURL,
|
||||
"--server-cert-fingerprint", sp.SHA256Fingerprint,
|
||||
"--override-username", userName,
|
||||
"--override-hostname", userHost,
|
||||
"--password", userPassword)
|
||||
|
||||
clientEnv.RunAndExpectSuccess(t, "repo", "disconnect")
|
||||
|
||||
userPassword2 := "bad-password-" + strconv.Itoa(int(rand.Int31()))
|
||||
|
||||
out = e.RunAndExpectSuccess(t, "server", "users", "hash-password", "--user-password", userPassword2)
|
||||
|
||||
require.Len(t, out, 1)
|
||||
|
||||
passwordHash2 := out[0]
|
||||
require.NotEmpty(t, passwordHash2)
|
||||
|
||||
// set new user password using the password hash and refresh the server
|
||||
e.RunAndExpectSuccess(t, "server", "users", "set", userFull, "--user-password-hash", passwordHash2)
|
||||
e.RunAndExpectSuccess(t, "server", "refresh",
|
||||
"--address", sp.BaseURL,
|
||||
"--server-cert-fingerprint", sp.SHA256Fingerprint,
|
||||
"--server-control-password", sp.ServerControlPassword)
|
||||
|
||||
// attempt connecting with the new password
|
||||
clientEnv.RunAndExpectSuccess(t, "repo", "connect", "server",
|
||||
"--url", sp.BaseURL,
|
||||
"--server-cert-fingerprint", sp.SHA256Fingerprint,
|
||||
"--override-username", userName,
|
||||
"--override-hostname", userHost,
|
||||
"--password", userPassword2)
|
||||
|
||||
clientEnv.RunAndExpectSuccess(t, "repo", "disconnect")
|
||||
}
|
||||
71
internal/user/hash_password.go
Normal file
71
internal/user/hash_password.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type passwordHash struct {
|
||||
PasswordHashVersion int `json:"passwordHashVersion"`
|
||||
PasswordHash []byte `json:"passwordHash"`
|
||||
}
|
||||
|
||||
// HashPassword computes the hash for the given password and an encoded hash
|
||||
// that can be passed to Profile.SetPasswordHash().
|
||||
func HashPassword(password string) (string, error) {
|
||||
const hashVersion = defaultPasswordHashVersion
|
||||
|
||||
salt := make([]byte, passwordHashSaltLength)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return "", errors.Wrap(err, "error generating salt")
|
||||
}
|
||||
|
||||
h, err := computePasswordHash(password, salt, hashVersion)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error hashing password")
|
||||
}
|
||||
|
||||
pwh := passwordHash{
|
||||
PasswordHashVersion: hashVersion,
|
||||
PasswordHash: h,
|
||||
}
|
||||
|
||||
j, err := json.Marshal(pwh)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "error encoding password hash")
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(j), nil
|
||||
}
|
||||
|
||||
func decodeHashedPassword(encodedHash string) (*passwordHash, error) {
|
||||
var h passwordHash
|
||||
|
||||
passwordHashJSON, err := base64.StdEncoding.DecodeString(encodedHash)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decoding password hash")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(passwordHashJSON, &h); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshalling password hash")
|
||||
}
|
||||
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
// validates hashing algorithm and password hash length.
|
||||
func (h *passwordHash) validate() error {
|
||||
if _, err := getPasswordHashAlgorithm(h.PasswordHashVersion); err != nil {
|
||||
return errors.Wrap(err, "invalid password hash version")
|
||||
}
|
||||
|
||||
if len(h.PasswordHash) != passwordHashSaltLength+passwordHashLength {
|
||||
return errors.Errorf("invalid hash length: %v", len(h.PasswordHash))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
85
internal/user/hash_password_test.go
Normal file
85
internal/user/hash_password_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
petname "github.com/dustinkirkland/golang-petname"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHashPassword_encoding(t *testing.T) {
|
||||
bogusPassword := petname.Generate(2, "+")
|
||||
|
||||
h, err := HashPassword(bogusPassword)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, h)
|
||||
|
||||
// roundtrip
|
||||
ph, err := decodeHashedPassword(h)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ph)
|
||||
require.NotZero(t, ph.PasswordHashVersion)
|
||||
require.NotEmpty(t, ph.PasswordHash)
|
||||
|
||||
p := Profile{
|
||||
PasswordHashVersion: ph.PasswordHashVersion,
|
||||
PasswordHash: ph.PasswordHash,
|
||||
}
|
||||
|
||||
valid, err := p.IsValidPassword(bogusPassword)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, valid)
|
||||
}
|
||||
|
||||
func TestPasswordHashValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
ph passwordHash
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
ph: passwordHash{
|
||||
PasswordHashVersion: -3,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
ph: passwordHash{
|
||||
PasswordHashVersion: defaultPasswordHashVersion,
|
||||
// empty PasswordHash
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
ph: passwordHash{
|
||||
PasswordHashVersion: defaultPasswordHashVersion,
|
||||
// PasswordHash with invalid length
|
||||
PasswordHash: []byte{'z', 'a'},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
ph: passwordHash{
|
||||
PasswordHashVersion: defaultPasswordHashVersion,
|
||||
PasswordHash: make([]byte, passwordHashSaltLength+passwordHashLength),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range cases {
|
||||
t.Run("i_"+strconv.Itoa(i), func(t *testing.T) {
|
||||
gotErr := tc.ph.validate()
|
||||
if tc.expectError {
|
||||
require.Error(t, gotErr)
|
||||
} else {
|
||||
require.NoError(t, gotErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
18
internal/user/password_hashing_version.go
Normal file
18
internal/user/password_hashing_version.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package user
|
||||
|
||||
import "github.com/pkg/errors"
|
||||
|
||||
// defaultPasswordHashVersion is the default scheme used for user password hashing.
|
||||
const defaultPasswordHashVersion = ScryptHashVersion
|
||||
|
||||
// getPasswordHashAlgorithm returns the password hash algorithm given a version.
|
||||
func getPasswordHashAlgorithm(passwordHashVersion int) (string, error) {
|
||||
switch passwordHashVersion {
|
||||
case ScryptHashVersion:
|
||||
return scryptHashAlgorithm, nil
|
||||
case Pbkdf2HashVersion:
|
||||
return pbkdf2HashAlgorithm, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash version (%d)", passwordHashVersion)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package user
|
||||
|
||||
// DefaultPasswordHashingAlgorithm is the default password hashing scheme for user profiles.
|
||||
const DefaultPasswordHashingAlgorithm = scryptHashAlgorithm
|
||||
|
||||
// PasswordHashingAlgorithms returns the supported algorithms for user password hashing.
|
||||
func PasswordHashingAlgorithms() []string {
|
||||
return []string{scryptHashAlgorithm, pbkdf2HashAlgorithm}
|
||||
}
|
||||
@@ -33,8 +33,9 @@ func TestSaltLengthIsSupported(t *testing.T) {
|
||||
const badPwd = "password"
|
||||
var salt [passwordHashSaltLength]byte
|
||||
|
||||
for _, h := range PasswordHashingAlgorithms() {
|
||||
_, err := computePasswordHash(badPwd, salt[:], h)
|
||||
for _, v := range []int{ScryptHashVersion, Pbkdf2HashVersion} {
|
||||
h, err := computePasswordHash(badPwd, salt[:], v)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,9 @@
|
||||
// ErrUserNotFound is returned to indicate that a user was not found in the system.
|
||||
var ErrUserNotFound = errors.New("user not found")
|
||||
|
||||
// ErrUserAlreadyExists indicates that a user already exist in the system when attempting to create a new one.
|
||||
var ErrUserAlreadyExists = errors.New("user already exists")
|
||||
|
||||
// LoadProfileMap returns the map of all users profiles in the repository by username, using old map as a cache.
|
||||
func LoadProfileMap(ctx context.Context, rep repo.Repository, old map[string]*Profile) (map[string]*Profile, error) {
|
||||
if rep == nil {
|
||||
@@ -99,6 +102,32 @@ func GetUserProfile(ctx context.Context, r repo.Repository, username string) (*P
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetNewProfile returns a profile for a new user with the given username.
|
||||
// Returns ErrUserAlreadyExists when the user already exists.
|
||||
func GetNewProfile(ctx context.Context, r repo.Repository, username string) (*Profile, error) {
|
||||
if err := ValidateUsername(username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manifests, err := r.FindManifests(ctx, map[string]string{
|
||||
manifest.TypeLabelKey: ManifestType,
|
||||
UsernameAtHostnameLabel: username,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error looking for user profile")
|
||||
}
|
||||
|
||||
if len(manifests) != 0 {
|
||||
return nil, errors.Wrap(ErrUserAlreadyExists, username)
|
||||
}
|
||||
|
||||
return &Profile{
|
||||
Username: username,
|
||||
PasswordHashVersion: defaultPasswordHashVersion,
|
||||
},
|
||||
nil
|
||||
}
|
||||
|
||||
// validUsernameRegexp matches username@hostname where both username and hostname consist of
|
||||
// lowercase letters, digits or dashes, underscores or period characters.
|
||||
var validUsernameRegexp = regexp.MustCompile(`^[a-z0-9\-_.]+@[a-z0-9\-_.]+$`)
|
||||
|
||||
@@ -59,6 +59,29 @@ func TestUserManager(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNewProfile(t *testing.T) {
|
||||
ctx, env := repotesting.NewEnvironment(t, repotesting.FormatNotImportant)
|
||||
|
||||
p, err := user.GetNewProfile(ctx, env.RepositoryWriter, "alice@somehost")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p)
|
||||
|
||||
err = p.SetPassword("badpassword")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.SetUserProfile(ctx, env.RepositoryWriter, p)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err = user.GetNewProfile(ctx, env.RepositoryWriter, p.Username)
|
||||
require.ErrorIs(t, err, user.ErrUserAlreadyExists)
|
||||
require.Nil(t, p)
|
||||
|
||||
p, err = user.GetNewProfile(ctx, env.RepositoryWriter, "nonexisting@somehost")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p)
|
||||
}
|
||||
|
||||
func TestValidateUsername_Valid(t *testing.T) {
|
||||
cases := []string{
|
||||
"foo@bar",
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/kopia/kopia/repo/manifest"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,56 +33,34 @@ func (p *Profile) SetPassword(password string) error {
|
||||
return p.setPassword(password)
|
||||
}
|
||||
|
||||
// IsValidPassword determines whether the password is valid for a given user.
|
||||
func (p *Profile) IsValidPassword(password string) (bool, error) {
|
||||
var invalidProfile bool
|
||||
|
||||
var passwordHashAlgorithm string
|
||||
|
||||
var err error
|
||||
|
||||
if p == nil {
|
||||
invalidProfile = true
|
||||
} else {
|
||||
passwordHashAlgorithm, err = getPasswordHashAlgorithm(p.PasswordHashVersion)
|
||||
if err != nil {
|
||||
invalidProfile = true
|
||||
}
|
||||
// SetPasswordHash decodes and validates encodedhash, if it is a valid hash
|
||||
// then it sets it as the password hash for the user profile.
|
||||
func (p *Profile) SetPasswordHash(encodedHash string) error {
|
||||
ph, err := decodeHashedPassword(encodedHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if invalidProfile {
|
||||
algorithms := PasswordHashingAlgorithms()
|
||||
// if the user profile is invalid, either a non-existing user name or password
|
||||
// hash version, then return false but use the same amount of time as when we
|
||||
// compare against valid user to avoid revealing whether the user account exists.
|
||||
_, err := isValidPassword(password, dummyHashThatNeverMatchesAnyPassword, algorithms[rand.Intn(len(algorithms))]) //nolint:gosec
|
||||
if err := ph.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.PasswordHashVersion = ph.PasswordHashVersion
|
||||
p.PasswordHash = ph.PasswordHash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidPassword determines whether the password is valid for a given user.
|
||||
func (p *Profile) IsValidPassword(password string) (bool, error) {
|
||||
if p == nil {
|
||||
// return false when the user profile does not exist,
|
||||
// but use the same amount of time as when checking the password for a
|
||||
// valid user to avoid revealing whether the account exists.
|
||||
_, err := isValidPassword(password, dummyHashThatNeverMatchesAnyPassword, defaultPasswordHashVersion)
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
return isValidPassword(password, p.PasswordHash, passwordHashAlgorithm)
|
||||
}
|
||||
|
||||
// getPasswordHashAlgorithm returns the password hash algorithm given a version.
|
||||
func getPasswordHashAlgorithm(passwordHashVersion int) (string, error) {
|
||||
switch passwordHashVersion {
|
||||
case ScryptHashVersion:
|
||||
return scryptHashAlgorithm, nil
|
||||
case Pbkdf2HashVersion:
|
||||
return pbkdf2HashAlgorithm, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash version (%d)", passwordHashVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// GetPasswordHashVersion returns the password hash version given an algorithm.
|
||||
func GetPasswordHashVersion(passwordHashAlgorithm string) (int, error) {
|
||||
switch passwordHashAlgorithm {
|
||||
case scryptHashAlgorithm:
|
||||
return ScryptHashVersion, nil
|
||||
case pbkdf2HashAlgorithm:
|
||||
return Pbkdf2HashVersion, nil
|
||||
default:
|
||||
return 0, errors.Errorf("unsupported hash algorithm (%s)", passwordHashAlgorithm)
|
||||
}
|
||||
return isValidPassword(password, p.PasswordHash, p.PasswordHashVersion)
|
||||
}
|
||||
|
||||
@@ -24,23 +24,25 @@ func initDummyHash() []byte {
|
||||
}
|
||||
|
||||
func (p *Profile) setPassword(password string) error {
|
||||
passwordHashAlgorithm, err := getPasswordHashAlgorithm(p.PasswordHashVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
salt := make([]byte, passwordHashSaltLength)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return errors.Wrap(err, "error generating salt")
|
||||
}
|
||||
|
||||
p.PasswordHash, err = computePasswordHash(password, salt, passwordHashAlgorithm)
|
||||
var err error
|
||||
|
||||
p.PasswordHash, err = computePasswordHash(password, salt, p.PasswordHashVersion)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func computePasswordHash(password string, salt []byte, keyDerivationAlgorithm string) ([]byte, error) {
|
||||
key, err := crypto.DeriveKeyFromPassword(password, salt, passwordHashLength, keyDerivationAlgorithm)
|
||||
func computePasswordHash(password string, salt []byte, passwordHashVersion int) ([]byte, error) {
|
||||
hashingAlgo, err := getPasswordHashAlgorithm(passwordHashVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := crypto.DeriveKeyFromPassword(password, salt, passwordHashLength, hashingAlgo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error hashing password")
|
||||
}
|
||||
@@ -50,14 +52,14 @@ func computePasswordHash(password string, salt []byte, keyDerivationAlgorithm st
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func isValidPassword(password string, hashedPassword []byte, keyDerivationAlgorithm string) (bool, error) {
|
||||
func isValidPassword(password string, hashedPassword []byte, passwordHashVersion int) (bool, error) {
|
||||
if len(hashedPassword) != passwordHashSaltLength+passwordHashLength {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
salt := hashedPassword[0:passwordHashSaltLength]
|
||||
|
||||
h, err := computePasswordHash(password, salt, keyDerivationAlgorithm)
|
||||
h, err := computePasswordHash(password, salt, passwordHashVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestBadPasswordHashVersion(t *testing.T) {
|
||||
isValid, err = p.IsValidPassword("foo")
|
||||
|
||||
require.False(t, isValid, "password unexpectedly valid!")
|
||||
require.NoError(t, err)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNilUserProfile(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user