mirror of
https://github.com/kopia/kopia.git
synced 2026-05-16 10:44:40 -04:00
feat(providers): treat token expiration errors as non-retryable (#1675)
Token expiration errors should be treated as non-retryable errors.
This commit is contained in:
@@ -113,6 +113,19 @@ func AssertGetBlobNotFound(ctx context.Context, t *testing.T, s blob.Storage, bl
|
||||
}
|
||||
}
|
||||
|
||||
// AssertInvalidCredentials asserts that GetBlob() for specified blobID returns ErrInvalidCredentials.
|
||||
func AssertInvalidCredentials(ctx context.Context, t *testing.T, s blob.Storage, blobID blob.ID) {
|
||||
t.Helper()
|
||||
|
||||
var b gather.WriteBuffer
|
||||
defer b.Close()
|
||||
|
||||
err := s.GetBlob(ctx, blobID, 0, -1, &b)
|
||||
if !errors.Is(err, blob.ErrInvalidCredentials) {
|
||||
t.Fatalf("GetBlob(%v) returned %v but expected ErrInvalidCredentials", blobID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertGetMetadataNotFound asserts that GetMetadata() for specified blobID returns ErrNotFound.
|
||||
func AssertGetMetadataNotFound(ctx context.Context, t *testing.T, s blob.Storage, blobID blob.ID) {
|
||||
t.Helper()
|
||||
|
||||
@@ -71,6 +71,9 @@ func isRetriable(err error) bool {
|
||||
case errors.Is(err, blob.ErrSetTimeUnsupported):
|
||||
return false
|
||||
|
||||
case errors.Is(err, blob.ErrInvalidCredentials):
|
||||
return false
|
||||
|
||||
case errors.Is(err, blob.ErrUnsupportedPutBlobOption):
|
||||
return false
|
||||
|
||||
|
||||
@@ -83,9 +83,17 @@ func (s *s3Storage) getBlobWithVersion(ctx context.Context, b blob.ID, version s
|
||||
return blob.EnsureLengthExactly(output.Length(), length)
|
||||
}
|
||||
|
||||
func isInvalidCredentials(err error) bool {
|
||||
return err != nil && strings.Contains(err.Error(), blob.InvalidCredentialsErrStr)
|
||||
}
|
||||
|
||||
func translateError(err error) error {
|
||||
var me minio.ErrorResponse
|
||||
|
||||
if isInvalidCredentials(err) {
|
||||
return blob.ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if errors.As(err, &me) {
|
||||
switch me.StatusCode {
|
||||
case http.StatusOK:
|
||||
@@ -171,6 +179,10 @@ func (s *s3Storage) putBlob(ctx context.Context, b blob.ID, data blob.Bytes, opt
|
||||
Mode: retentionMode,
|
||||
})
|
||||
|
||||
if isInvalidCredentials(err) {
|
||||
return versionMetadata{}, blob.ErrInvalidCredentials
|
||||
}
|
||||
|
||||
var er minio.ErrorResponse
|
||||
|
||||
if errors.As(err, &er) && er.Code == "InvalidRequest" && strings.Contains(strings.ToLower(er.Message), "content-md5") {
|
||||
@@ -224,6 +236,10 @@ func (s *s3Storage) ListBlobs(ctx context.Context, prefix blob.ID, callback func
|
||||
})
|
||||
for o := range oi {
|
||||
if err := o.Err; err != nil {
|
||||
if isInvalidCredentials(err) {
|
||||
return blob.ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -292,12 +308,16 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) {
|
||||
}
|
||||
|
||||
func newStorage(ctx context.Context, opt *Options) (*s3Storage, error) {
|
||||
return newStorageWithCredentials(ctx, credentials.NewStaticV4(opt.AccessKeyID, opt.SecretAccessKey, opt.SessionToken), opt)
|
||||
}
|
||||
|
||||
func newStorageWithCredentials(ctx context.Context, creds *credentials.Credentials, opt *Options) (*s3Storage, error) {
|
||||
if opt.BucketName == "" {
|
||||
return nil, errors.New("bucket name must be specified")
|
||||
}
|
||||
|
||||
minioOpts := &minio.Options{
|
||||
Creds: credentials.NewStaticV4(opt.AccessKeyID, opt.SecretAccessKey, opt.SessionToken),
|
||||
Creds: creds,
|
||||
Secure: !opt.DoNotUseTLS,
|
||||
Region: opt.Region,
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -30,6 +31,7 @@
|
||||
"github.com/kopia/kopia/internal/timetrack"
|
||||
"github.com/kopia/kopia/internal/tlsutil"
|
||||
"github.com/kopia/kopia/repo/blob"
|
||||
"github.com/kopia/kopia/repo/blob/retrying"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -42,7 +44,8 @@
|
||||
minioBucketName = "my-bucket" // we use ephemeral minio for each test so this does not need to be unique
|
||||
|
||||
// default aws S3 endpoint.
|
||||
awsEndpoint = "s3.amazonaws.com"
|
||||
awsEndpoint = "s3.amazonaws.com"
|
||||
awsStsEndpointUSWest2 = "https://sts.us-west-2.amazonaws.com"
|
||||
|
||||
// env vars need to be set to execute TestS3StorageAWS.
|
||||
testEndpointEnv = "KOPIA_S3_TEST_ENDPOINT"
|
||||
@@ -51,6 +54,7 @@
|
||||
testBucketEnv = "KOPIA_S3_TEST_BUCKET"
|
||||
testLockedBucketEnv = "KOPIA_S3_TEST_LOCKED_BUCKET"
|
||||
testRegionEnv = "KOPIA_S3_TEST_REGION"
|
||||
testRoleEnv = "KOPIA_S3_TEST_ROLE"
|
||||
// additional env vars need to be set to execute TestS3StorageAWSSTS.
|
||||
testSTSAccessKeyIDEnv = "KOPIA_S3_TEST_STS_ACCESS_KEY_ID"
|
||||
testSTSSecretAccessKeyEnv = "KOPIA_S3_TEST_STS_SECRET_ACCESS_KEY"
|
||||
@@ -126,6 +130,40 @@ func getProviderOptions(tb testing.TB, envName string) *Options {
|
||||
return &o
|
||||
}
|
||||
|
||||
// verifyInvalidCredentialsForGetBlob verifies that the invalid credentials
|
||||
// error is returned by GetBlob.
|
||||
// nolint:thelper
|
||||
func verifyInvalidCredentialsForGetBlob(ctx context.Context, t *testing.T, r blob.Storage) {
|
||||
blocks := []struct {
|
||||
blk blob.ID
|
||||
contents []byte
|
||||
}{
|
||||
{blk: "abcdbbf4f0507d054ed5a80a5b65086f602b", contents: []byte{}},
|
||||
{blk: "zxce0e35630770c54668a8cfb4e414c6bf8f", contents: []byte{1}},
|
||||
}
|
||||
|
||||
for _, b := range blocks {
|
||||
blobtesting.AssertInvalidCredentials(ctx, t, r, b.blk)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyBlobNotFoundForGetBlob verifies that the ErrBlobNotFound
|
||||
// error is returned by GetBlob.
|
||||
// nolint:thelper
|
||||
func verifyBlobNotFoundForGetBlob(ctx context.Context, t *testing.T, r blob.Storage) {
|
||||
blocks := []struct {
|
||||
blk blob.ID
|
||||
contents []byte
|
||||
}{
|
||||
{blk: "abcdbbf4f0507d054ed5a80a5b65086f602b", contents: []byte{}},
|
||||
{blk: "zxce0e35630770c54668a8cfb4e414c6bf8f", contents: []byte{1}},
|
||||
}
|
||||
|
||||
for _, b := range blocks {
|
||||
blobtesting.AssertGetBlobNotFound(ctx, t, r, b.blk)
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3StorageProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -243,6 +281,70 @@ func TestS3StorageRetentionLockedBucket(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
testutil.ProviderTest(t)
|
||||
|
||||
awsAccessKeyID := getEnv(testAccessKeyIDEnv, "")
|
||||
awsSecretAccessKeyID := getEnv(testSecretAccessKeyEnv, "")
|
||||
bucketName := getEnvOrSkip(t, testBucketEnv)
|
||||
region := getEnvOrSkip(t, testRegionEnv)
|
||||
role := getEnvOrSkip(t, testRoleEnv)
|
||||
|
||||
// Get the credentials and custom provider
|
||||
creds, customProvider := customCredentialsAndProvider(awsAccessKeyID, awsSecretAccessKeyID, role, region)
|
||||
|
||||
// Verify that the credentials can be used to get a new value
|
||||
val, err := creds.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
createBucket(t, &Options{
|
||||
Endpoint: awsEndpoint,
|
||||
AccessKeyID: awsAccessKeyID,
|
||||
SecretAccessKey: awsSecretAccessKeyID,
|
||||
BucketName: bucketName,
|
||||
Region: region,
|
||||
DoNotUseTLS: true,
|
||||
})
|
||||
|
||||
require.NotEqual(t, awsAccessKeyID, val.AccessKeyID)
|
||||
require.NotEqual(t, awsSecretAccessKeyID, val.SecretAccessKey)
|
||||
|
||||
// Create new storage using the credentials
|
||||
ctx := testlogging.Context(t)
|
||||
|
||||
st, err := newStorageWithCredentials(ctx, creds, &Options{
|
||||
Endpoint: awsEndpoint,
|
||||
BucketName: bucketName,
|
||||
Region: region,
|
||||
DoNotUseTLS: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
rst := retrying.NewWrapper(st)
|
||||
|
||||
// Since the session token is valid at this point
|
||||
// we expect errors that indicate that the blob was not found.
|
||||
// customProvider.expired is false at this point since the customProvider
|
||||
// was initialized with false.
|
||||
verifyBlobNotFoundForGetBlob(ctx, t, rst)
|
||||
|
||||
// Atomic set the expired flag to true here to force token expiration.
|
||||
// After this we expect to get token expiration errors.
|
||||
customProvider.forceExpired.Store(true)
|
||||
verifyInvalidCredentialsForGetBlob(ctx, t, rst)
|
||||
|
||||
// Reset the expired flag and expire the credentials, so that a new valid token
|
||||
// is obtained by the client.
|
||||
creds.Expire()
|
||||
customProvider.forceExpired.Store(false)
|
||||
verifyBlobNotFoundForGetBlob(ctx, t, rst)
|
||||
}
|
||||
|
||||
func TestS3StorageMinio(t *testing.T) {
|
||||
t.Parallel()
|
||||
testutil.ProviderTest(t)
|
||||
@@ -625,3 +727,72 @@ func createMinioSessionToken(t *testing.T, minioEndpoint, kopiaUserName, kopiaUs
|
||||
|
||||
return *result.Credentials.AccessKeyId, *result.Credentials.SecretAccessKey, *result.Credentials.SessionToken
|
||||
}
|
||||
|
||||
// customProvider is a custom provider based on minio's STSAssumeRole struct
|
||||
// that implements the logic for retrieving
|
||||
// credentials and checking if the credentials
|
||||
// have expired.
|
||||
// The expired field is used to allow the user of this
|
||||
// provider to force expiration of the credentials. This causes
|
||||
// the next call to Retrieve to return expired credentials.
|
||||
type customProvider struct {
|
||||
forceExpired atomic.Value
|
||||
stsProvider miniocreds.STSAssumeRole
|
||||
}
|
||||
|
||||
const expiredSessionToken = "IQoJb3JpZ2luX2VjEBMaCXVzLXdlc3QtMiJIM" +
|
||||
"EYCIQDCu87ZTm4eMNLRvcFgkYycknuxWz8yZ8PQaElWZWameAIhAMOQlDkUqO" +
|
||||
"HEsoRqCYAF1anKEuhgdrC8x1KaqlAb81nsKpwCCDwQAxoMMDM2Nzc2MzQwMTA" +
|
||||
"yIgy03tG3mSbTUIsW83kq+QFIl2JcsjOQn2pqVmobXRHhZLmHWhFA0ti99Myn" +
|
||||
"JA5Hj2rp1aK1zhEcA650pocUkXldMMvZ0qSShGggeIy7+6Y9XE7JXZpo/QKna" +
|
||||
"0TJXTcxcjdgmgLm4vdxJRtdMaDdXmx3gKPuti+ez211tVjJLTjKdGMUH8jQoA" +
|
||||
"qLe6jvF3ARWODP0SySAO/q3Q/eQDtwdMf/fYBmRVOtIOzPV7obzCQ45PsJkcE" +
|
||||
"Ae60XFO5C47gbwne4eSEiipKAAA4zCJAA9pfa1S++4il8eMifGc3XDjvddn9i" +
|
||||
"A0/tNI8bjsbCF1t9VtVcvLcaK7MOvMrNeNztLO8GyNxgcv9uUC0w0+KtjwY6n" +
|
||||
"AGTxeDWJUKBfXuc7CeUgpjuflTf4aAq+Gpe5T+m2FbStRMgk6uThtPiw53EUC" +
|
||||
"w/1tyUNysTAn1bYffmLVhRU9CP86Hj23C01/IeLjXzSXAF8T6nv7nmAO50D7l" +
|
||||
"RCcVWcntllxyL/sUZ7VbMr7xZxWWbilu8pVtQqTwwBxZO0rth8XftMzGQ5oyd" +
|
||||
"82CdcwRB+t7K1LEmRErltbteGtM="
|
||||
|
||||
func (cp *customProvider) Retrieve() (miniocreds.Value, error) {
|
||||
if cp.forceExpired.Load().(bool) {
|
||||
return miniocreds.Value{
|
||||
AccessKeyID: "ASIAQREAKNKDBR4F5F2I",
|
||||
SecretAccessKey: "EF82nKmZbnFETa96xxx1C3k20hG4Nw+2v+FBNjp3",
|
||||
SessionToken: expiredSessionToken,
|
||||
SignerType: miniocreds.SignatureV2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return cp.stsProvider.Retrieve()
|
||||
}
|
||||
|
||||
func (cp *customProvider) IsExpired() bool {
|
||||
return cp.forceExpired.Load().(bool)
|
||||
}
|
||||
|
||||
// customCredentialsAndProvider creates a custom provider and returns credentials
|
||||
// using this provider.
|
||||
func customCredentialsAndProvider(accessKey, secretKey, roleARN, region string) (*miniocreds.Credentials, *customProvider) {
|
||||
opts := miniocreds.STSAssumeRoleOptions{
|
||||
AccessKey: accessKey,
|
||||
SecretKey: secretKey,
|
||||
Location: region,
|
||||
RoleARN: roleARN,
|
||||
RoleSessionName: "s3-test-session",
|
||||
}
|
||||
stsEndpoint := awsStsEndpointUSWest2
|
||||
cp := &customProvider{
|
||||
stsProvider: miniocreds.STSAssumeRole{
|
||||
Client: &http.Client{
|
||||
Transport: http.DefaultTransport,
|
||||
},
|
||||
STSEndpoint: stsEndpoint,
|
||||
Options: opts,
|
||||
},
|
||||
}
|
||||
// Initialize expired to false
|
||||
cp.forceExpired.Store(false)
|
||||
|
||||
return miniocreds.New(cp), cp
|
||||
}
|
||||
|
||||
@@ -17,6 +17,14 @@
|
||||
// ErrInvalidRange is returned when the requested blob offset or length is invalid.
|
||||
var ErrInvalidRange = errors.Errorf("invalid blob offset or length")
|
||||
|
||||
// InvalidCredentialsErrStr is the error string returned by the provider
|
||||
// when a token has expired.
|
||||
const InvalidCredentialsErrStr = "The provided token has expired"
|
||||
|
||||
// ErrInvalidCredentials is returned when the token used for
|
||||
// authenticating with a storage provider has expired.
|
||||
var ErrInvalidCredentials = errors.Errorf(InvalidCredentialsErrStr)
|
||||
|
||||
// ErrBlobAlreadyExists is returned when attempting to put a blob that already exists.
|
||||
var ErrBlobAlreadyExists = errors.New("blob already exists")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user