s3: remove session token on cross-host redirects

Add a redirect policy to the S3 HTTP client so X-Amz-Security-Token is
removed once a redirect chain crosses hosts. Keep stripping it on later
same-host hops in the same chain, since net/http copies headers from the
initial request for each redirect and can otherwise restore the token.

Preserve same-host redirect behavior, retain the standard redirect limit,
and add tests for cross-host, same-host, multi-hop, and redirect-loop cases.
This commit is contained in:
IceLocke
2026-05-28 10:33:42 +00:00
committed by Nick Craig-Wood
parent c96385c280
commit e7b1eb774c
2 changed files with 132 additions and 1 deletions

View File

@@ -1351,10 +1351,34 @@ func getClient(ctx context.Context, opt *Options) *http.Client {
}
})
return &http.Client{
Transport: t,
Transport: t,
CheckRedirect: s3CheckRedirect,
}
}
func s3CheckRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
if s3RedirectCrossesHost(req, via) {
req.Header.Del("X-Amz-Security-Token")
}
return nil
}
func s3RedirectCrossesHost(req *http.Request, via []*http.Request) bool {
if len(via) == 0 {
return false
}
host := via[0].URL.Host
for _, redirect := range via[1:] {
if redirect.URL.Host != host {
return true
}
}
return host != req.URL.Host
}
// Fixup the request if needed.
//
// Google Cloud Storage alters the Accept-Encoding header, which

View File

@@ -4,6 +4,7 @@ package s3
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
@@ -22,6 +23,112 @@ func SetupS3Test(t *testing.T) (context.Context, *Options, *http.Client) {
return ctx, opt, client
}
func TestClientRemovesSecurityTokenOnCrossHostRedirect(t *testing.T) {
ctx, _, client := SetupS3Test(t)
redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get("X-Amz-Security-Token"))
assert.Equal(t, "date", r.Header.Get("X-Amz-Date"))
w.WriteHeader(http.StatusOK)
}))
defer redirectServer.Close()
initialServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token"))
http.Redirect(w, r, redirectServer.URL, http.StatusTemporaryRedirect)
}))
defer initialServer.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, initialServer.URL, nil)
require.NoError(t, err)
req.Header.Set("X-Amz-Security-Token", "token")
req.Header.Set("X-Amz-Date", "date")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NoError(t, resp.Body.Close())
}
func TestClientDoesNotRestoreSecurityTokenAfterCrossHostRedirect(t *testing.T) {
ctx, _, client := SetupS3Test(t)
redirectServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/middle":
assert.Empty(t, r.Header.Get("X-Amz-Security-Token"))
assert.Equal(t, "date", r.Header.Get("X-Amz-Date"))
http.Redirect(w, r, "/final", http.StatusTemporaryRedirect)
case "/final":
assert.Empty(t, r.Header.Get("X-Amz-Security-Token"))
assert.Equal(t, "date", r.Header.Get("X-Amz-Date"))
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
}))
defer redirectServer.Close()
initialServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token"))
http.Redirect(w, r, redirectServer.URL+"/middle", http.StatusTemporaryRedirect)
}))
defer initialServer.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, initialServer.URL, nil)
require.NoError(t, err)
req.Header.Set("X-Amz-Security-Token", "token")
req.Header.Set("X-Amz-Date", "date")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NoError(t, resp.Body.Close())
}
func TestClientKeepsSecurityTokenOnSameHostRedirect(t *testing.T) {
ctx, _, client := SetupS3Test(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/":
assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token"))
http.Redirect(w, r, "/redirected", http.StatusTemporaryRedirect)
case "/redirected":
assert.Equal(t, "token", r.Header.Get("X-Amz-Security-Token"))
w.WriteHeader(http.StatusOK)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil)
require.NoError(t, err)
req.Header.Set("X-Amz-Security-Token", "token")
resp, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.NoError(t, resp.Body.Close())
}
func TestClientStopsAfterTenRedirects(t *testing.T) {
_, _, client := SetupS3Test(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
}))
defer server.Close()
resp, err := client.Get(server.URL)
if resp != nil {
_ = resp.Body.Close()
}
require.Error(t, err)
assert.Contains(t, err.Error(), "stopped after 10 redirects")
}
// TestIntegration runs integration tests against the remote
func TestIntegration(t *testing.T) {
opt := &fstests.Opt{