From 5fa39fe5eb6347cada69c4839fcaea7e6e97a6fe Mon Sep 17 00:00:00 2001 From: Sirish Bathina Date: Wed, 10 Apr 2024 14:56:13 -1000 Subject: [PATCH] feat(general): User_profile_add_set cli changes (#3770) * User profile add set cli changes * Add additional test * get salt length from key deriver * Fixes for failing tests * after shikhars review * lint fixes --- cli/command_user_add_set.go | 20 +++---- internal/auth/authn_repo.go | 3 +- internal/auth/authn_repo_test.go | 53 ++++++++++++++++--- internal/crypto/key_derivation_nontest.go | 50 +++++++++++++----- internal/crypto/key_derivation_testing.go | 28 +++++++--- internal/crypto/pbkdf.go | 51 ++++++++++++------ internal/crypto/scrypt.go | 51 +++++++++++++++--- internal/user/user_profile.go | 34 +++++++----- internal/user/user_profile_hash_v1.go | 49 ++++++++++------- internal/user/user_profile_test.go | 64 +++++++++++++++++++---- 10 files changed, 301 insertions(+), 102 deletions(-) diff --git a/cli/command_user_add_set.go b/cli/command_user_add_set.go index 14589db56..9a72b5437 100644 --- a/cli/command_user_add_set.go +++ b/cli/command_user_add_set.go @@ -13,11 +13,11 @@ ) type commandServerUserAddSet struct { - userAskPassword bool - userSetName string - userSetPassword string - userSetPasswordHashVersion int - userSetPasswordHash string + userAskPassword bool + userSetName string + userSetPassword string + keyDerivationAlgorithm string + userSetPasswordHash string isNew bool // true == 'add', false == 'update' out textOutput @@ -37,7 +37,7 @@ func (c *commandServerUserAddSet) setup(svc appServices, parent commandParent, i cmd.Flag("ask-password", "Ask for user password").BoolVar(&c.userAskPassword) cmd.Flag("user-password", "Password").StringVar(&c.userSetPassword) cmd.Flag("user-password-hash", "Password hash").StringVar(&c.userSetPasswordHash) - cmd.Flag("user-password-hash-version", "Password hash version").Default("1").IntVar(&c.userSetPasswordHashVersion) + cmd.Flag("key-derivation-algorithm", "Key derivation algorithm").Default(crypto.DefaultKeyDerivationAlgorithm).EnumVar(&c.keyDerivationAlgorithm, crypto.AllowedKeyDerivationAlgorithms()...) cmd.Arg("username", "Username").Required().StringVar(&c.userSetName) cmd.Action(svc.repositoryWriterAction(c.runServerUserAddSet)) @@ -54,7 +54,8 @@ func (c *commandServerUserAddSet) getExistingOrNewUserProfile(ctx context.Contex case errors.Is(err, user.ErrUserNotFound): return &user.Profile{ - Username: username, + Username: username, + KeyDerivationAlgorithm: c.keyDerivationAlgorithm, }, nil } } @@ -75,7 +76,7 @@ func (c *commandServerUserAddSet) runServerUserAddSet(ctx context.Context, rep r if p := c.userSetPassword; p != "" { changed = true - if err := up.SetPassword(p, crypto.DefaultKeyDerivationAlgorithm); err != nil { + if err := up.SetPassword(p); err != nil { return errors.Wrap(err, "error setting password") } } @@ -86,7 +87,6 @@ func (c *commandServerUserAddSet) runServerUserAddSet(ctx context.Context, rep r return errors.Wrap(err, "invalid password hash, must be valid base64 string") } - up.PasswordHashVersion = c.userSetPasswordHashVersion up.PasswordHash = ph changed = true } @@ -108,7 +108,7 @@ func (c *commandServerUserAddSet) runServerUserAddSet(ctx context.Context, rep r changed = true - if err := up.SetPassword(pwd, crypto.DefaultKeyDerivationAlgorithm); err != nil { + if err := up.SetPassword(pwd); err != nil { return errors.Wrap(err, "error setting password") } } diff --git a/internal/auth/authn_repo.go b/internal/auth/authn_repo.go index 67466e655..9df1373dd 100644 --- a/internal/auth/authn_repo.go +++ b/internal/auth/authn_repo.go @@ -6,7 +6,6 @@ "time" "github.com/kopia/kopia/internal/clock" - "github.com/kopia/kopia/internal/crypto" "github.com/kopia/kopia/internal/user" "github.com/kopia/kopia/repo" ) @@ -52,7 +51,7 @@ func (ac *repositoryUserAuthenticator) IsValid(ctx context.Context, rep repo.Rep // IsValidPassword can be safely called on nil and the call will take as much time as for a valid user // thus not revealing anything about whether the user exists. - return ac.userProfiles[username].IsValidPassword(password, crypto.DefaultKeyDerivationAlgorithm) + return ac.userProfiles[username].IsValidPassword(password) } func (ac *repositoryUserAuthenticator) Refresh(ctx context.Context) error { diff --git a/internal/auth/authn_repo_test.go b/internal/auth/authn_repo_test.go index 27cd9e458..d145884b2 100644 --- a/internal/auth/authn_repo_test.go +++ b/internal/auth/authn_repo_test.go @@ -19,13 +19,45 @@ func TestRepositoryAuthenticator(t *testing.T) { require.NoError(t, repo.WriteSession(ctx, env.Repository, repo.WriteSessionOptions{}, func(ctx context.Context, w repo.RepositoryWriter) error { - p := &user.Profile{ - Username: "user1@host1", + for _, tc := range []struct { + profile *user.Profile + password string + }{ + { + profile: &user.Profile{ + Username: "user1@host1", + PasswordHashVersion: crypto.HashVersion1, + }, + password: "password1", + }, + { + profile: &user.Profile{ + Username: "user2@host2", + KeyDerivationAlgorithm: crypto.ScryptAlgorithm, + }, + password: "password2", + }, + { + profile: &user.Profile{ + Username: "user3@host3", + }, + password: "password3", + }, + { + profile: &user.Profile{ + Username: "user4@host4", + KeyDerivationAlgorithm: crypto.Pbkdf2Algorithm, + }, + password: "password4", + }, + } { + tc.profile.SetPassword(tc.password) + err := user.SetUserProfile(ctx, w, tc.profile) + if err != nil { + return err + } } - - p.SetPassword("password1", crypto.DefaultKeyDerivationAlgorithm) - - return user.SetUserProfile(ctx, w, p) + return nil })) verifyRepoAuthenticator(ctx, t, a, env.Repository, "user1@host1", "password1", true) @@ -33,6 +65,15 @@ func(ctx context.Context, w repo.RepositoryWriter) error { verifyRepoAuthenticator(ctx, t, a, env.Repository, "user1@host1", "password11", false) verifyRepoAuthenticator(ctx, t, a, env.Repository, "user1@host1a", "password1", false) verifyRepoAuthenticator(ctx, t, a, env.Repository, "user1@host1a", "password1a", false) + + // Test for password with KeyDerivationSet + verifyRepoAuthenticator(ctx, t, a, env.Repository, "user2@host2", "password2", true) + + // Test for User with neither key derivation or PasswordHashVersion set + verifyRepoAuthenticator(ctx, t, a, env.Repository, "user3@host3", "password3", false) + + // Test for PBKDF2 key derivation + verifyRepoAuthenticator(ctx, t, a, env.Repository, "user4@host4", "password4", true) } func verifyRepoAuthenticator(ctx context.Context, t *testing.T, a auth.Authenticator, r repo.Repository, username, password string, want bool) { diff --git a/internal/crypto/key_derivation_nontest.go b/internal/crypto/key_derivation_nontest.go index 9e30c8252..e8b8c7659 100644 --- a/internal/crypto/key_derivation_nontest.go +++ b/internal/crypto/key_derivation_nontest.go @@ -4,38 +4,64 @@ package crypto import ( + "fmt" + "github.com/pkg/errors" ) const ( // MasterKeyLength describes the length of the master key. MasterKeyLength = 32 - - // ScryptAlgorithm is the key for the scrypt algorithm. - ScryptAlgorithm = "scrypt-65536-8-1" - // Pbkdf2Algorithm is the key for the pbkdf algorithm. - Pbkdf2Algorithm = "pbkdf2" ) // DefaultKeyDerivationAlgorithm is the key derivation algorithm for new configurations. const DefaultKeyDerivationAlgorithm = ScryptAlgorithm -type keyDerivationFunc func(password string, salt []byte) ([]byte, error) +// KeyDeriver is an interface that contains methods for deriving a key from a password. +type KeyDeriver interface { + DeriveKeyFromPassword(password string, salt []byte) ([]byte, error) + RecommendedSaltLength() int +} //nolint:gochecknoglobals -var keyDerivers = map[string]keyDerivationFunc{} +var keyDerivers = map[string]KeyDeriver{} + +// RegisterKeyDerivers registers various key derivation functions. +func RegisterKeyDerivers(name string, keyDeriver KeyDeriver) { + if _, ok := keyDerivers[name]; ok { + panic(fmt.Sprintf("key deriver (%s) is already registered", name)) + } -// RegisterKeyDerivationFunc registers various key derivation functions. -func RegisterKeyDerivationFunc(name string, keyDeriver keyDerivationFunc) { keyDerivers[name] = keyDeriver } // DeriveKeyFromPassword derives encryption key using the provided password and per-repository unique ID. func DeriveKeyFromPassword(password string, salt []byte, algorithm string) ([]byte, error) { - kdFunc, ok := keyDerivers[algorithm] + kd, ok := keyDerivers[algorithm] if !ok { - return nil, errors.Errorf("unsupported key algorithm: %v", algorithm) + return nil, errors.Errorf("unsupported key algorithm: %v, supported algorithms %v", algorithm, AllowedKeyDerivationAlgorithms()) } - return kdFunc(password, salt) + //nolint:wrapcheck + return kd.DeriveKeyFromPassword(password, salt) +} + +// RecommendedSaltLength returns the recommended salt length of a given key derivation algorithm. +func RecommendedSaltLength(algorithm string) (int, error) { + kd, ok := keyDerivers[algorithm] + if !ok { + return 0, errors.Errorf("unsupported key algorithm: %v, supported algorithms %v", algorithm, AllowedKeyDerivationAlgorithms()) + } + + return kd.RecommendedSaltLength(), nil +} + +// AllowedKeyDerivationAlgorithms returns a slice of the allowed key derivation algorithms. +func AllowedKeyDerivationAlgorithms() []string { + kdAlgorithms := make([]string, 0, len(keyDerivers)) + for k := range keyDerivers { + kdAlgorithms = append(kdAlgorithms, k) + } + + return kdAlgorithms } diff --git a/internal/crypto/key_derivation_testing.go b/internal/crypto/key_derivation_testing.go index c56bbb55c..acbbe74b3 100644 --- a/internal/crypto/key_derivation_testing.go +++ b/internal/crypto/key_derivation_testing.go @@ -9,20 +9,28 @@ "github.com/pkg/errors" ) -// DefaultKeyDerivationAlgorithm is the key derivation algorithm for new configurations. -const DefaultKeyDerivationAlgorithm = "testing-only-insecure" +const ( + // DefaultKeyDerivationAlgorithm is the key derivation algorithm for new configurations. + DefaultKeyDerivationAlgorithm = "testing-only-insecure" -// MasterKeyLength describes the length of the master key. -const MasterKeyLength = 32 + // MasterKeyLength describes the length of the master key. + MasterKeyLength = 32 + + V1SaltLength = 32 + HashVersion1 = 1 // this translates to Scrypt KeyDerivationAlgorithm + ScryptAlgorithm = "scrypt-65536-8-1" + Pbkdf2Algorithm = "pbkdf2" +) // DeriveKeyFromPassword derives encryption key using the provided password and per-repository unique ID. func DeriveKeyFromPassword(password string, salt []byte, algorithm string) ([]byte, error) { const masterKeySize = 32 switch algorithm { - case DefaultKeyDerivationAlgorithm: + case DefaultKeyDerivationAlgorithm, ScryptAlgorithm, Pbkdf2Algorithm: h := sha256.New() - if _, err := h.Write([]byte(password)); err != nil { + // Adjust password so that we get a different key for each algorithm + if _, err := h.Write([]byte(password + algorithm)); err != nil { return nil, err } @@ -32,3 +40,11 @@ func DeriveKeyFromPassword(password string, salt []byte, algorithm string) ([]by return nil, errors.Errorf("unsupported key algorithm: %v", algorithm) } } + +func RecommendedSaltLength(algorithm string) (int, error) { + return V1SaltLength, nil +} + +func AllowedKeyDerivationAlgorithms() []string { + return []string{DefaultKeyDerivationAlgorithm} +} diff --git a/internal/crypto/pbkdf.go b/internal/crypto/pbkdf.go index e49fc931d..b89693fab 100644 --- a/internal/crypto/pbkdf.go +++ b/internal/crypto/pbkdf.go @@ -10,23 +10,44 @@ "golang.org/x/crypto/pbkdf2" ) -// The NIST recommended iterations for PBKDF2 with SHA256 hash is 600,000. -const pbkdf2Sha256Iterations = 1<<20 - 1<<18 // 786,432 +const ( + // The NIST recommended minimum size for a salt for pbkdf2 is 16 bytes. + // + // TBD: However, a good rule of thumb is to use a salt that is the same size + // as the output of the hash function. For example, the output of SHA256 + // is 256 bits (32 bytes), so the salt should be at least 32 random bytes. + // See: https://crackstation.net/hashing-security.htm + minPbkdfSha256SaltSize = 32 // size in bytes == 128 bits -// The NIST recommended minimum size for a salt for pbkdf2 is 16 bytes. -// -// TBD: However, a good rule of thumb is to use a salt that is the same size -// as the output of the hash function. For example, the output of SHA256 -// is 256 bits (32 bytes), so the salt should be at least 32 random bytes. -// See: https://crackstation.net/hashing-security.htm -const minPbkdfSha256SaltSize = 16 // size in bytes == 128 bits + // The NIST recommended iterations for PBKDF2 with SHA256 hash is 600,000. + pbkdf2Sha256Iterations = 1<<20 - 1<<18 // 786,432 + + // Pbkdf2Algorithm is the key for the pbkdf algorithm. + Pbkdf2Algorithm = "pbkdf2" +) func init() { - RegisterKeyDerivationFunc(Pbkdf2Algorithm, func(password string, salt []byte) ([]byte, error) { - if len(salt) < minPbkdfSha256SaltSize { - return nil, errors.Errorf("required salt size is atleast %d bytes", minPbkdfSha256SaltSize) - } - - return pbkdf2.Key([]byte(password), salt, pbkdf2Sha256Iterations, MasterKeyLength, sha256.New), nil + RegisterKeyDerivers(Pbkdf2Algorithm, &pbkdf2KeyDeriver{ + iterations: pbkdf2Sha256Iterations, + recommendedSaltLength: minPbkdfSha256SaltSize, + minSaltLength: minPbkdfSha256SaltSize, }) } + +type pbkdf2KeyDeriver struct { + iterations int + recommendedSaltLength int + minSaltLength int +} + +func (s *pbkdf2KeyDeriver) DeriveKeyFromPassword(password string, salt []byte) ([]byte, error) { + if len(salt) < s.minSaltLength { + return nil, errors.Errorf("required salt size is atleast %d bytes", s.minSaltLength) + } + + return pbkdf2.Key([]byte(password), salt, s.iterations, MasterKeyLength, sha256.New), nil +} + +func (s *pbkdf2KeyDeriver) RecommendedSaltLength() int { + return s.recommendedSaltLength +} diff --git a/internal/crypto/scrypt.go b/internal/crypto/scrypt.go index ab4d5da46..1758215dc 100644 --- a/internal/crypto/scrypt.go +++ b/internal/crypto/scrypt.go @@ -16,15 +16,50 @@ // is 256 bits (32 bytes), so the salt should be at least 32 random bytes. // Scrypt uses a SHA256 hash function. // https://crackstation.net/hashing-security.htm -const minScryptSha256SaltSize = 16 // size in bytes == 128 bits +const ( + minScryptSha256SaltSize = 16 // size in bytes == 128 bits + + // ScryptAlgorithm is the key for the scrypt algorithm. + ScryptAlgorithm = "scrypt-65536-8-1" + + // Legacy hash version salt length. + V1SaltLength = 32 + + // Legacy hash version system translates to KeyDerivationAlgorithm. + HashVersion1 = 1 // this translates to Scrypt KeyDerivationAlgorithm + +) func init() { - RegisterKeyDerivationFunc(ScryptAlgorithm, func(password string, salt []byte) ([]byte, error) { - if len(salt) < minScryptSha256SaltSize { - return nil, errors.Errorf("required salt size is atleast %d bytes", minPbkdfSha256SaltSize) - } - - //nolint:gomnd - return scrypt.Key([]byte(password), salt, 65536, 8, 1, MasterKeyLength) + RegisterKeyDerivers(ScryptAlgorithm, &scryptKeyDeriver{ + n: 65536, //nolint:gomnd + r: 8, //nolint:gomnd + p: 1, + recommendedSaltLength: V1SaltLength, + minSaltLength: minScryptSha256SaltSize, }) } + +type scryptKeyDeriver struct { + // n scryptCostParameterN is scrypt's CPU/memory cost parameter. + n int + // r scryptCostParameterR is scrypt's work factor. + r int + // p scryptCostParameterP is scrypt's parallelization parameter. + p int + + recommendedSaltLength int + minSaltLength int +} + +func (s *scryptKeyDeriver) DeriveKeyFromPassword(password string, salt []byte) ([]byte, error) { + if len(salt) < s.minSaltLength { + return nil, errors.Errorf("required salt size is atleast %d bytes", s.minSaltLength) + } + //nolint:wrapcheck + return scrypt.Key([]byte(password), salt, s.n, s.r, s.p, MasterKeyLength) +} + +func (s *scryptKeyDeriver) RecommendedSaltLength() int { + return s.recommendedSaltLength +} diff --git a/internal/user/user_profile.go b/internal/user/user_profile.go index 709eaaf05..53659cdf9 100644 --- a/internal/user/user_profile.go +++ b/internal/user/user_profile.go @@ -1,6 +1,9 @@ package user import ( + "math/rand" + + "github.com/kopia/kopia/internal/crypto" "github.com/kopia/kopia/repo/manifest" ) @@ -8,31 +11,36 @@ type Profile struct { ManifestID manifest.ID `json:"-"` - Username string `json:"username"` - PasswordHashVersion int `json:"passwordHashVersion"` // indicates how password is hashed - PasswordHash []byte `json:"passwordHash"` + Username string `json:"username"` + PasswordHashVersion int `json:"passwordHashVersion,omitempty"` // indicates how password is hashed, deprecated in favor of KeyDerivationAlgorithm + KeyDerivationAlgorithm string `json:"keyDerivationAlgorithm,omitempty"` + PasswordHash []byte `json:"passwordHash"` } // SetPassword changes the password for a user profile. -func (p *Profile) SetPassword(password, keyDerivationAlgorithm string) error { - return p.setPasswordV1(password, keyDerivationAlgorithm) +func (p *Profile) SetPassword(password string) error { + return p.setPassword(password) } // IsValidPassword determines whether the password is valid for a given user. -func (p *Profile) IsValidPassword(password, keyDerivationAlgorithm string) bool { +func (p *Profile) IsValidPassword(password string) bool { if p == nil { + algorithms := crypto.AllowedKeyDerivationAlgorithms() // if the user is invalid, return false but use the same amount of time as when we // compare against valid user to avoid revealing whether the user account exists. - isValidPasswordV1(password, dummyV1HashThatNeverMatchesAnyPassword, keyDerivationAlgorithm) + isValidPassword(password, dummyV1HashThatNeverMatchesAnyPassword, algorithms[rand.Intn(len(algorithms))]) //nolint:gosec return false } - switch p.PasswordHashVersion { - case hashVersion1: - return isValidPasswordV1(password, p.PasswordHash, keyDerivationAlgorithm) - - default: - return false + // Legacy case where password hash version is set + if p.PasswordHashVersion == crypto.HashVersion1 { + return isValidPassword(password, p.PasswordHash, crypto.ScryptAlgorithm) } + + if p.KeyDerivationAlgorithm != "" { + return isValidPassword(password, p.PasswordHash, p.KeyDerivationAlgorithm) + } + + return false } diff --git a/internal/user/user_profile_hash_v1.go b/internal/user/user_profile_hash_v1.go index ab8c3514d..4a11757b9 100644 --- a/internal/user/user_profile_hash_v1.go +++ b/internal/user/user_profile_hash_v1.go @@ -10,31 +10,37 @@ "github.com/kopia/kopia/internal/crypto" ) -// parameters for v1 hashing. -const ( - hashVersion1 = 1 - - v1SaltLength = 32 -) - //nolint:gochecknoglobals -var dummyV1HashThatNeverMatchesAnyPassword = make([]byte, crypto.MasterKeyLength+v1SaltLength) +var dummyV1HashThatNeverMatchesAnyPassword = make([]byte, crypto.MasterKeyLength+crypto.V1SaltLength) -func (p *Profile) setPasswordV1(password, keyDerivationAlgorithm string) error { - salt := make([]byte, v1SaltLength) +func (p *Profile) setPassword(password string) error { + keyDerivationAlgorithm := p.KeyDerivationAlgorithm + if keyDerivationAlgorithm == "" { + if p.PasswordHashVersion == 0 { + return errors.New("key derivation algorithm and password hash version not set") + } + // Setup to handle legacy hashVersion. + if p.PasswordHashVersion == crypto.HashVersion1 { + keyDerivationAlgorithm = crypto.ScryptAlgorithm + } + } + + saltLength, err := crypto.RecommendedSaltLength(keyDerivationAlgorithm) + if err != nil { + return errors.Wrap(err, "error getting recommended salt length") + } + + salt := make([]byte, saltLength) if _, err := io.ReadFull(rand.Reader, salt); err != nil { return errors.Wrap(err, "error generating salt") } - var err error - - p.PasswordHashVersion = 1 - p.PasswordHash, err = computePasswordHashV1(password, salt, keyDerivationAlgorithm) + p.PasswordHash, err = computePasswordHash(password, salt, keyDerivationAlgorithm) return err } -func computePasswordHashV1(password string, salt []byte, keyDerivationAlgorithm string) ([]byte, error) { +func computePasswordHash(password string, salt []byte, keyDerivationAlgorithm string) ([]byte, error) { key, err := crypto.DeriveKeyFromPassword(password, salt, keyDerivationAlgorithm) if err != nil { return nil, errors.Wrap(err, "error deriving key from password") @@ -45,14 +51,19 @@ func computePasswordHashV1(password string, salt []byte, keyDerivationAlgorithm return payload, nil } -func isValidPasswordV1(password string, hashedPassword []byte, keyDerivationAlgorithm string) bool { - if len(hashedPassword) != v1SaltLength+crypto.MasterKeyLength { +func isValidPassword(password string, hashedPassword []byte, keyDerivationAlgorithm string) bool { + saltLength, err := crypto.RecommendedSaltLength(keyDerivationAlgorithm) + if err != nil { + panic(err) + } + + if len(hashedPassword) != saltLength+crypto.MasterKeyLength { return false } - salt := hashedPassword[0:v1SaltLength] + salt := hashedPassword[0:saltLength] - h, err := computePasswordHashV1(password, salt, keyDerivationAlgorithm) + h, err := computePasswordHash(password, salt, keyDerivationAlgorithm) if err != nil { panic(err) } diff --git a/internal/user/user_profile_test.go b/internal/user/user_profile_test.go index 1497f25cb..e9088ad55 100644 --- a/internal/user/user_profile_test.go +++ b/internal/user/user_profile_test.go @@ -7,20 +7,56 @@ "github.com/kopia/kopia/internal/user" ) -func TestUserProfile(t *testing.T) { - p := &user.Profile{} +func TestLegacyUserProfile(t *testing.T) { + p := &user.Profile{ + PasswordHashVersion: 1, // hashVersion1 + } - if p.IsValidPassword("bar", crypto.DefaultKeyDerivationAlgorithm) { + if p.IsValidPassword("bar") { t.Fatalf("password unexpectedly valid!") } - p.SetPassword("foo", crypto.DefaultKeyDerivationAlgorithm) + p.SetPassword("foo") - if !p.IsValidPassword("foo", crypto.DefaultKeyDerivationAlgorithm) { + if !p.IsValidPassword("foo") { t.Fatalf("password not valid!") } - if p.IsValidPassword("bar", crypto.DefaultKeyDerivationAlgorithm) { + if p.IsValidPassword("bar") { + t.Fatalf("password unexpectedly valid!") + } + + // Setting the key derivation to scrypt and unsetting PasswordHashVersion + // Legacy profile should translate to scrypt + p.KeyDerivationAlgorithm = crypto.ScryptAlgorithm + p.PasswordHashVersion = 0 + if !p.IsValidPassword("foo") { + t.Fatalf("password not valid!") + } +} + +func TestUserProfile(t *testing.T) { + p := &user.Profile{ + KeyDerivationAlgorithm: crypto.ScryptAlgorithm, + } + + if p.IsValidPassword("bar") { + t.Fatalf("password unexpectedly valid!") + } + + p.SetPassword("foo") + + if !p.IsValidPassword("foo") { + t.Fatalf("password not valid!") + } + + if p.IsValidPassword("bar") { + t.Fatalf("password unexpectedly valid!") + } + + // Different key derivation algorithm besides the original should fail + p.KeyDerivationAlgorithm = crypto.Pbkdf2Algorithm + if p.IsValidPassword("foo") { t.Fatalf("password unexpectedly valid!") } } @@ -28,9 +64,15 @@ func TestUserProfile(t *testing.T) { func TestBadKeyDerivationAlgorithmPanic(t *testing.T) { defer func() { _ = recover() }() func() { - p := &user.Profile{} - p.SetPassword("foo", crypto.DefaultKeyDerivationAlgorithm) - p.IsValidPassword("foo", "bad algorithm") + // mock a valid password + p := &user.Profile{ + KeyDerivationAlgorithm: crypto.ScryptAlgorithm, + } + p.SetPassword("foo") + // Assume the key derivation algorithm is bad. This will cause + // a panic when validating + p.KeyDerivationAlgorithm = "some bad algorithm" + p.IsValidPassword("foo") }() t.Errorf("should have panicked") } @@ -38,7 +80,7 @@ func() { func TestNilUserProfile(t *testing.T) { var p *user.Profile - if p.IsValidPassword("bar", crypto.DefaultKeyDerivationAlgorithm) { + if p.IsValidPassword("bar") { t.Fatalf("password unexpectedly valid!") } } @@ -51,7 +93,7 @@ func TestInvalidPasswordHash(t *testing.T) { for _, tc := range cases { p := &user.Profile{PasswordHash: tc} - if p.IsValidPassword("some-password", crypto.DefaultKeyDerivationAlgorithm) { + if p.IsValidPassword("some-password") { t.Fatalf("password unexpectedly valid for %v", tc) } }