diff --git a/internal/blobtesting/asserts.go b/internal/blobtesting/asserts.go index eb7b956e3..45fe473b0 100644 --- a/internal/blobtesting/asserts.go +++ b/internal/blobtesting/asserts.go @@ -113,6 +113,19 @@ func AssertGetBlobNotFound(ctx context.Context, t *testing.T, s blob.Storage, bl } } +// AssertInvalidCredentials asserts that GetBlob() for specified blobID returns ErrInvalidCredentials. +func AssertInvalidCredentials(ctx context.Context, t *testing.T, s blob.Storage, blobID blob.ID) { + t.Helper() + + var b gather.WriteBuffer + defer b.Close() + + err := s.GetBlob(ctx, blobID, 0, -1, &b) + if !errors.Is(err, blob.ErrInvalidCredentials) { + t.Fatalf("GetBlob(%v) returned %v but expected ErrInvalidCredentials", blobID, err) + } +} + // AssertGetMetadataNotFound asserts that GetMetadata() for specified blobID returns ErrNotFound. func AssertGetMetadataNotFound(ctx context.Context, t *testing.T, s blob.Storage, blobID blob.ID) { t.Helper() diff --git a/repo/blob/retrying/retrying_storage.go b/repo/blob/retrying/retrying_storage.go index c3f95c79e..3ce463223 100644 --- a/repo/blob/retrying/retrying_storage.go +++ b/repo/blob/retrying/retrying_storage.go @@ -71,6 +71,9 @@ func isRetriable(err error) bool { case errors.Is(err, blob.ErrSetTimeUnsupported): return false + case errors.Is(err, blob.ErrInvalidCredentials): + return false + case errors.Is(err, blob.ErrUnsupportedPutBlobOption): return false diff --git a/repo/blob/s3/s3_storage.go b/repo/blob/s3/s3_storage.go index 3903dc10f..f4c9e1607 100644 --- a/repo/blob/s3/s3_storage.go +++ b/repo/blob/s3/s3_storage.go @@ -83,9 +83,17 @@ func (s *s3Storage) getBlobWithVersion(ctx context.Context, b blob.ID, version s return blob.EnsureLengthExactly(output.Length(), length) } +func isInvalidCredentials(err error) bool { + return err != nil && strings.Contains(err.Error(), blob.InvalidCredentialsErrStr) +} + func translateError(err error) error { var me minio.ErrorResponse + if isInvalidCredentials(err) { + return blob.ErrInvalidCredentials + } + if errors.As(err, &me) { switch me.StatusCode { case http.StatusOK: @@ -171,6 +179,10 @@ func (s *s3Storage) putBlob(ctx context.Context, b blob.ID, data blob.Bytes, opt Mode: retentionMode, }) + if isInvalidCredentials(err) { + return versionMetadata{}, blob.ErrInvalidCredentials + } + var er minio.ErrorResponse if errors.As(err, &er) && er.Code == "InvalidRequest" && strings.Contains(strings.ToLower(er.Message), "content-md5") { @@ -224,6 +236,10 @@ func (s *s3Storage) ListBlobs(ctx context.Context, prefix blob.ID, callback func }) for o := range oi { if err := o.Err; err != nil { + if isInvalidCredentials(err) { + return blob.ErrInvalidCredentials + } + return err } @@ -292,12 +308,16 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) { } func newStorage(ctx context.Context, opt *Options) (*s3Storage, error) { + return newStorageWithCredentials(ctx, credentials.NewStaticV4(opt.AccessKeyID, opt.SecretAccessKey, opt.SessionToken), opt) +} + +func newStorageWithCredentials(ctx context.Context, creds *credentials.Credentials, opt *Options) (*s3Storage, error) { if opt.BucketName == "" { return nil, errors.New("bucket name must be specified") } minioOpts := &minio.Options{ - Creds: credentials.NewStaticV4(opt.AccessKeyID, opt.SecretAccessKey, opt.SessionToken), + Creds: creds, Secure: !opt.DoNotUseTLS, Region: opt.Region, } diff --git a/repo/blob/s3/s3_storage_test.go b/repo/blob/s3/s3_storage_test.go index 2d45a1e63..dc6ed849e 100644 --- a/repo/blob/s3/s3_storage_test.go +++ b/repo/blob/s3/s3_storage_test.go @@ -10,6 +10,7 @@ "os" "path/filepath" "strings" + "sync/atomic" "testing" "time" @@ -30,6 +31,7 @@ "github.com/kopia/kopia/internal/timetrack" "github.com/kopia/kopia/internal/tlsutil" "github.com/kopia/kopia/repo/blob" + "github.com/kopia/kopia/repo/blob/retrying" ) const ( @@ -42,7 +44,8 @@ minioBucketName = "my-bucket" // we use ephemeral minio for each test so this does not need to be unique // default aws S3 endpoint. - awsEndpoint = "s3.amazonaws.com" + awsEndpoint = "s3.amazonaws.com" + awsStsEndpointUSWest2 = "https://sts.us-west-2.amazonaws.com" // env vars need to be set to execute TestS3StorageAWS. testEndpointEnv = "KOPIA_S3_TEST_ENDPOINT" @@ -51,6 +54,7 @@ testBucketEnv = "KOPIA_S3_TEST_BUCKET" testLockedBucketEnv = "KOPIA_S3_TEST_LOCKED_BUCKET" testRegionEnv = "KOPIA_S3_TEST_REGION" + testRoleEnv = "KOPIA_S3_TEST_ROLE" // additional env vars need to be set to execute TestS3StorageAWSSTS. testSTSAccessKeyIDEnv = "KOPIA_S3_TEST_STS_ACCESS_KEY_ID" testSTSSecretAccessKeyEnv = "KOPIA_S3_TEST_STS_SECRET_ACCESS_KEY" @@ -126,6 +130,40 @@ func getProviderOptions(tb testing.TB, envName string) *Options { return &o } +// verifyInvalidCredentialsForGetBlob verifies that the invalid credentials +// error is returned by GetBlob. +// nolint:thelper +func verifyInvalidCredentialsForGetBlob(ctx context.Context, t *testing.T, r blob.Storage) { + blocks := []struct { + blk blob.ID + contents []byte + }{ + {blk: "abcdbbf4f0507d054ed5a80a5b65086f602b", contents: []byte{}}, + {blk: "zxce0e35630770c54668a8cfb4e414c6bf8f", contents: []byte{1}}, + } + + for _, b := range blocks { + blobtesting.AssertInvalidCredentials(ctx, t, r, b.blk) + } +} + +// verifyBlobNotFoundForGetBlob verifies that the ErrBlobNotFound +// error is returned by GetBlob. +// nolint:thelper +func verifyBlobNotFoundForGetBlob(ctx context.Context, t *testing.T, r blob.Storage) { + blocks := []struct { + blk blob.ID + contents []byte + }{ + {blk: "abcdbbf4f0507d054ed5a80a5b65086f602b", contents: []byte{}}, + {blk: "zxce0e35630770c54668a8cfb4e414c6bf8f", contents: []byte{1}}, + } + + for _, b := range blocks { + blobtesting.AssertGetBlobNotFound(ctx, t, r, b.blk) + } +} + func TestS3StorageProviders(t *testing.T) { t.Parallel() @@ -243,6 +281,70 @@ func TestS3StorageRetentionLockedBucket(t *testing.T) { }) } +func TestTokenExpiration(t *testing.T) { + t.Parallel() + testutil.ProviderTest(t) + + awsAccessKeyID := getEnv(testAccessKeyIDEnv, "") + awsSecretAccessKeyID := getEnv(testSecretAccessKeyEnv, "") + bucketName := getEnvOrSkip(t, testBucketEnv) + region := getEnvOrSkip(t, testRegionEnv) + role := getEnvOrSkip(t, testRoleEnv) + + // Get the credentials and custom provider + creds, customProvider := customCredentialsAndProvider(awsAccessKeyID, awsSecretAccessKeyID, role, region) + + // Verify that the credentials can be used to get a new value + val, err := creds.Get() + if err != nil { + t.Fatalf("err: %v", err) + } + + createBucket(t, &Options{ + Endpoint: awsEndpoint, + AccessKeyID: awsAccessKeyID, + SecretAccessKey: awsSecretAccessKeyID, + BucketName: bucketName, + Region: region, + DoNotUseTLS: true, + }) + + require.NotEqual(t, awsAccessKeyID, val.AccessKeyID) + require.NotEqual(t, awsSecretAccessKeyID, val.SecretAccessKey) + + // Create new storage using the credentials + ctx := testlogging.Context(t) + + st, err := newStorageWithCredentials(ctx, creds, &Options{ + Endpoint: awsEndpoint, + BucketName: bucketName, + Region: region, + DoNotUseTLS: true, + }) + if err != nil { + t.Fatalf("err: %v", err) + } + + rst := retrying.NewWrapper(st) + + // Since the session token is valid at this point + // we expect errors that indicate that the blob was not found. + // customProvider.expired is false at this point since the customProvider + // was initialized with false. + verifyBlobNotFoundForGetBlob(ctx, t, rst) + + // Atomic set the expired flag to true here to force token expiration. + // After this we expect to get token expiration errors. + customProvider.forceExpired.Store(true) + verifyInvalidCredentialsForGetBlob(ctx, t, rst) + + // Reset the expired flag and expire the credentials, so that a new valid token + // is obtained by the client. + creds.Expire() + customProvider.forceExpired.Store(false) + verifyBlobNotFoundForGetBlob(ctx, t, rst) +} + func TestS3StorageMinio(t *testing.T) { t.Parallel() testutil.ProviderTest(t) @@ -625,3 +727,72 @@ func createMinioSessionToken(t *testing.T, minioEndpoint, kopiaUserName, kopiaUs return *result.Credentials.AccessKeyId, *result.Credentials.SecretAccessKey, *result.Credentials.SessionToken } + +// customProvider is a custom provider based on minio's STSAssumeRole struct +// that implements the logic for retrieving +// credentials and checking if the credentials +// have expired. +// The expired field is used to allow the user of this +// provider to force expiration of the credentials. This causes +// the next call to Retrieve to return expired credentials. +type customProvider struct { + forceExpired atomic.Value + stsProvider miniocreds.STSAssumeRole +} + +const expiredSessionToken = "IQoJb3JpZ2luX2VjEBMaCXVzLXdlc3QtMiJIM" + + "EYCIQDCu87ZTm4eMNLRvcFgkYycknuxWz8yZ8PQaElWZWameAIhAMOQlDkUqO" + + "HEsoRqCYAF1anKEuhgdrC8x1KaqlAb81nsKpwCCDwQAxoMMDM2Nzc2MzQwMTA" + + "yIgy03tG3mSbTUIsW83kq+QFIl2JcsjOQn2pqVmobXRHhZLmHWhFA0ti99Myn" + + "JA5Hj2rp1aK1zhEcA650pocUkXldMMvZ0qSShGggeIy7+6Y9XE7JXZpo/QKna" + + "0TJXTcxcjdgmgLm4vdxJRtdMaDdXmx3gKPuti+ez211tVjJLTjKdGMUH8jQoA" + + "qLe6jvF3ARWODP0SySAO/q3Q/eQDtwdMf/fYBmRVOtIOzPV7obzCQ45PsJkcE" + + "Ae60XFO5C47gbwne4eSEiipKAAA4zCJAA9pfa1S++4il8eMifGc3XDjvddn9i" + + "A0/tNI8bjsbCF1t9VtVcvLcaK7MOvMrNeNztLO8GyNxgcv9uUC0w0+KtjwY6n" + + "AGTxeDWJUKBfXuc7CeUgpjuflTf4aAq+Gpe5T+m2FbStRMgk6uThtPiw53EUC" + + "w/1tyUNysTAn1bYffmLVhRU9CP86Hj23C01/IeLjXzSXAF8T6nv7nmAO50D7l" + + "RCcVWcntllxyL/sUZ7VbMr7xZxWWbilu8pVtQqTwwBxZO0rth8XftMzGQ5oyd" + + "82CdcwRB+t7K1LEmRErltbteGtM=" + +func (cp *customProvider) Retrieve() (miniocreds.Value, error) { + if cp.forceExpired.Load().(bool) { + return miniocreds.Value{ + AccessKeyID: "ASIAQREAKNKDBR4F5F2I", + SecretAccessKey: "EF82nKmZbnFETa96xxx1C3k20hG4Nw+2v+FBNjp3", + SessionToken: expiredSessionToken, + SignerType: miniocreds.SignatureV2, + }, nil + } + + return cp.stsProvider.Retrieve() +} + +func (cp *customProvider) IsExpired() bool { + return cp.forceExpired.Load().(bool) +} + +// customCredentialsAndProvider creates a custom provider and returns credentials +// using this provider. +func customCredentialsAndProvider(accessKey, secretKey, roleARN, region string) (*miniocreds.Credentials, *customProvider) { + opts := miniocreds.STSAssumeRoleOptions{ + AccessKey: accessKey, + SecretKey: secretKey, + Location: region, + RoleARN: roleARN, + RoleSessionName: "s3-test-session", + } + stsEndpoint := awsStsEndpointUSWest2 + cp := &customProvider{ + stsProvider: miniocreds.STSAssumeRole{ + Client: &http.Client{ + Transport: http.DefaultTransport, + }, + STSEndpoint: stsEndpoint, + Options: opts, + }, + } + // Initialize expired to false + cp.forceExpired.Store(false) + + return miniocreds.New(cp), cp +} diff --git a/repo/blob/storage.go b/repo/blob/storage.go index 8f33b2845..38330e6b7 100644 --- a/repo/blob/storage.go +++ b/repo/blob/storage.go @@ -17,6 +17,14 @@ // ErrInvalidRange is returned when the requested blob offset or length is invalid. var ErrInvalidRange = errors.Errorf("invalid blob offset or length") +// InvalidCredentialsErrStr is the error string returned by the provider +// when a token has expired. +const InvalidCredentialsErrStr = "The provided token has expired" + +// ErrInvalidCredentials is returned when the token used for +// authenticating with a storage provider has expired. +var ErrInvalidCredentials = errors.Errorf(InvalidCredentialsErrStr) + // ErrBlobAlreadyExists is returned when attempting to put a blob that already exists. var ErrBlobAlreadyExists = errors.New("blob already exists")