diff --git a/backend/s3/s3.go b/backend/s3/s3.go index f59545d8a..490aac221 100644 --- a/backend/s3/s3.go +++ b/backend/s3/s3.go @@ -1351,10 +1351,34 @@ func getClient(ctx context.Context, opt *Options) *http.Client { } }) return &http.Client{ - Transport: t, + Transport: t, + CheckRedirect: s3CheckRedirect, } } +func s3CheckRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + if s3RedirectCrossesHost(req, via) { + req.Header.Del("X-Amz-Security-Token") + } + return nil +} + +func s3RedirectCrossesHost(req *http.Request, via []*http.Request) bool { + if len(via) == 0 { + return false + } + host := via[0].URL.Host + for _, redirect := range via[1:] { + if redirect.URL.Host != host { + return true + } + } + return host != req.URL.Host +} + // Fixup the request if needed. // // Google Cloud Storage alters the Accept-Encoding header, which diff --git a/backend/s3/s3_test.go b/backend/s3/s3_test.go index a262d1703..316fd79fd 100644 --- a/backend/s3/s3_test.go +++ b/backend/s3/s3_test.go @@ -4,6 +4,7 @@ package s3 import ( "context" "net/http" + "net/http/httptest" "testing" "time" @@ -22,6 +23,112 @@ func SetupS3Test(t *testing.T) (context.Context, *Options, *http.Client) { return ctx, opt, client } +func TestClientRemovesSecurityTokenOnCrossHostRedirect(t *testing.T) { + ctx, _, client := SetupS3Test(t) + + redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Amz-Security-Token")) + assert.Equal(t, "date", r.Header.Get("X-Amz-Date")) + w.WriteHeader(http.StatusOK) + })) + defer redirectServer.Close() + + initialServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token")) + http.Redirect(w, r, redirectServer.URL, http.StatusTemporaryRedirect) + })) + defer initialServer.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, initialServer.URL, nil) + require.NoError(t, err) + req.Header.Set("X-Amz-Security-Token", "token") + req.Header.Set("X-Amz-Date", "date") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestClientDoesNotRestoreSecurityTokenAfterCrossHostRedirect(t *testing.T) { + ctx, _, client := SetupS3Test(t) + + redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/middle": + assert.Empty(t, r.Header.Get("X-Amz-Security-Token")) + assert.Equal(t, "date", r.Header.Get("X-Amz-Date")) + http.Redirect(w, r, "/final", http.StatusTemporaryRedirect) + case "/final": + assert.Empty(t, r.Header.Get("X-Amz-Security-Token")) + assert.Equal(t, "date", r.Header.Get("X-Amz-Date")) + w.WriteHeader(http.StatusOK) + default: + http.NotFound(w, r) + } + })) + defer redirectServer.Close() + + initialServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token")) + http.Redirect(w, r, redirectServer.URL+"/middle", http.StatusTemporaryRedirect) + })) + defer initialServer.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, initialServer.URL, nil) + require.NoError(t, err) + req.Header.Set("X-Amz-Security-Token", "token") + req.Header.Set("X-Amz-Date", "date") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestClientKeepsSecurityTokenOnSameHostRedirect(t *testing.T) { + ctx, _, client := SetupS3Test(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/": + assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token")) + http.Redirect(w, r, "/redirected", http.StatusTemporaryRedirect) + case "/redirected": + assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token")) + w.WriteHeader(http.StatusOK) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + require.NoError(t, err) + req.Header.Set("X-Amz-Security-Token", "token") + + resp, err := client.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NoError(t, resp.Body.Close()) +} + +func TestClientStopsAfterTenRedirects(t *testing.T) { + _, _, client := SetupS3Test(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) + })) + defer server.Close() + + resp, err := client.Get(server.URL) + if resp != nil { + _ = resp.Body.Close() + } + require.Error(t, err) + assert.Contains(t, err.Error(), "stopped after 10 redirects") +} + // TestIntegration runs integration tests against the remote func TestIntegration(t *testing.T) { opt := &fstests.Opt{