From 2751384cef6f9821e11a1e602656b277df46c6f0 Mon Sep 17 00:00:00 2001 From: Leander Beernaert Date: Mon, 27 Mar 2023 09:33:45 +0200 Subject: [PATCH] fix(GODT-2514): Apply Retry-After to 503 status Apply the same retry-after code for 429 replies request to 503 replies. --- manager_test.go | 35 +++++++++++++++++++++++++++++++++++ response.go | 4 ++-- server/rate_limit.go | 10 +++++++--- server/router.go | 2 +- server/server_builder.go | 20 +++++++++++++++----- 5 files changed, 60 insertions(+), 11 deletions(-) diff --git a/manager_test.go b/manager_test.go index 9e1a3b7..02e76f5 100644 --- a/manager_test.go +++ b/manager_test.go @@ -102,6 +102,41 @@ func TestHandleTooManyRequests(t *testing.T) { } } +func TestHandleTooManyRequests503(t *testing.T) { + // Create a server with a rate limit of 1 request per second. + s := server.New(server.WithRateLimitAndCustomStatusCode(1, time.Second, http.StatusServiceUnavailable)) + defer s.Close() + + var calls []server.Call + + // Watch the calls made. + s.AddCallWatcher(func(call server.Call) { + calls = append(calls, call) + }) + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + // Make five calls; they should all succeed, but will be rate limited. + for i := 0; i < 5; i++ { + require.NoError(t, m.Ping(context.Background())) + } + + // After each 503 response, we should wait at least the requested duration before making the next request. + for idx, call := range calls { + if call.Status == http.StatusServiceUnavailable { + after, err := strconv.Atoi(call.ResponseHeader.Get("Retry-After")) + require.NoError(t, err) + + // The next call should be made after the requested duration. + require.True(t, calls[idx+1].Time.After(call.Time.Add(time.Duration(after)*time.Second))) + } + } +} + func TestHandleTooManyRequests_Malformed(t *testing.T) { var calls []time.Time diff --git a/response.go b/response.go index 0b8b868..0d2dbb7 100644 --- a/response.go +++ b/response.go @@ -114,7 +114,7 @@ func updateTime(_ *resty.Client, res *resty.Response) error { // nolint:gosec func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) { // 0 and no error means default behaviour which is exponential backoff with jitter. - if res.StatusCode() != http.StatusTooManyRequests { + if res.StatusCode() != http.StatusTooManyRequests && res.StatusCode() != http.StatusServiceUnavailable { return 0, nil } @@ -139,7 +139,7 @@ func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error } func catchTooManyRequests(res *resty.Response, _ error) bool { - return res.StatusCode() == http.StatusTooManyRequests + return res.StatusCode() == http.StatusTooManyRequests || res.StatusCode() == http.StatusServiceUnavailable } func catchDialError(res *resty.Response, err error) bool { diff --git a/server/rate_limit.go b/server/rate_limit.go index 29bb59e..d4e557f 100644 --- a/server/rate_limit.go +++ b/server/rate_limit.go @@ -22,12 +22,16 @@ type rateLimiter struct { // countLock is a mutex for the callCount. countLock sync.Mutex + + // statusCode to reply with + statusCode int } -func newRateLimiter(limit int, window time.Duration) *rateLimiter { +func newRateLimiter(limit int, window time.Duration, statusCode int) *rateLimiter { return &rateLimiter{ - limit: limit, - window: window, + limit: limit, + window: window, + statusCode: statusCode, } } diff --git a/server/router.go b/server/router.go index 9f334c8..a2da95a 100644 --- a/server/router.go +++ b/server/router.go @@ -202,7 +202,7 @@ func (s *Server) applyRateLimit() gin.HandlerFunc { if wait := s.rateLimit.exceeded(); wait > 0 { c.Header("Retry-After", strconv.Itoa(int(wait.Seconds()))) - c.AbortWithStatus(http.StatusTooManyRequests) + c.AbortWithStatus(s.rateLimit.statusCode) } } } diff --git a/server/server_builder.go b/server/server_builder.go index 1fbd229..516f8e8 100644 --- a/server/server_builder.go +++ b/server/server_builder.go @@ -186,18 +186,28 @@ func (opt withAuthCache) config(builder *serverBuilder) { func WithRateLimit(limit int, window time.Duration) Option { return &withRateLimit{ - limit: limit, - window: window, + limit: limit, + window: window, + statusCode: http.StatusTooManyRequests, + } +} + +func WithRateLimitAndCustomStatusCode(limit int, window time.Duration, code int) Option { + return &withRateLimit{ + limit: limit, + window: window, + statusCode: code, } } type withRateLimit struct { - limit int - window time.Duration + limit int + statusCode int + window time.Duration } func (opt withRateLimit) config(builder *serverBuilder) { - builder.rateLimiter = newRateLimiter(opt.limit, opt.window) + builder.rateLimiter = newRateLimiter(opt.limit, opt.window, opt.statusCode) } func WithProxyTransport(transport *http.Transport) Option {