Files
navidrome/server/throttle_backlog.go
Deluan 2a43c4683e chore: go fix
Signed-off-by: Deluan <deluan@navidrome.org>
2026-05-28 22:13:05 -03:00

150 lines
3.6 KiB
Go

package server
import (
"bytes"
"context"
"errors"
"maps"
"net/http"
"sync"
"time"
"github.com/go-chi/chi/v5/middleware"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
)
// Sentinel errors reported when a request cannot obtain a processing slot.
var (
	// ErrThrottleCapacityExceeded is returned when no backlog slot is free at
	// the moment the request arrives, so the request is not even queued.
	ErrThrottleCapacityExceeded = errors.New("throttle: capacity exceeded")

	// ErrThrottleTimeout is returned when a queued request waited for the
	// whole backlog timeout without getting a processing slot.
	ErrThrottleTimeout = errors.New("throttle: backlog timeout")
)
// requestThrottle is a token-based limiter for concurrent request processing.
// Both channels are used as token pools: receiving takes a token, sending
// gives it back.
type requestThrottle struct {
	// tokens holds one token per request that may be processed concurrently.
	// backlogTokens holds one token per request that may be either running or
	// waiting; a request with no backlog token is rejected right away.
	tokens, backlogTokens chan struct{}

	// backlogTimeout is how long a request holding a backlog token waits for
	// a processing token before giving up.
	backlogTimeout time.Duration
}
// ThrottleBacklog builds a Chi-compatible middleware that caps how many
// requests are processed at the same time.
//
// A limit of zero or less disables throttling: the returned middleware hands
// back the next handler untouched. When conf.Server.DevArtworkThrottleBuffered
// is off, the work is delegated to Chi's own middleware.ThrottleBacklog.
//
// Otherwise a buffering throttle is used: up to limit requests run
// concurrently, up to backlogLimit more may wait (for at most backlogTimeout)
// for a free slot, and anything beyond that is rejected. The handler's
// response is captured in memory while the slot is held; the slot is released
// before the captured response is written to the client, so a slow client
// does not keep throttle capacity busy.
//
// Since the whole response is held in memory, apply this only to endpoints
// with small responses (e.g., artwork images), never to audio streaming or
// download endpoints.
//
// NOTE(review): nothing in this file sets a write deadline for the final
// write to the client — confirm one is enforced elsewhere if it is required.
func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler {
	switch {
	case limit <= 0:
		// Throttling disabled: pass-through middleware.
		return func(next http.Handler) http.Handler { return next }
	case !conf.Server.DevArtworkThrottleBuffered:
		return middleware.ThrottleBacklog(limit, backlogLimit, backlogTimeout)
	}

	// Every admitted request (running or waiting) owns one backlog token.
	admitted := limit + backlogLimit
	throttle := &requestThrottle{
		tokens:         make(chan struct{}, limit),
		backlogTokens:  make(chan struct{}, admitted),
		backlogTimeout: backlogTimeout,
	}

	// Both pools start full.
	for i := 0; i < limit; i++ {
		throttle.tokens <- struct{}{}
	}
	for i := 0; i < admitted; i++ {
		throttle.backlogTokens <- struct{}{}
	}
	return throttle.handler
}
// handler wraps next with the buffering throttle.
//
// Each request first tries to acquire a processing slot. If that fails the
// client receives 429 Too Many Requests; capacity and timeout failures are
// additionally logged as warnings. On success, next runs against an in-memory
// writer, the slot is released as soon as next returns (or panics), and only
// then are the captured headers, status code and body copied to the real
// ResponseWriter.
func (t *requestThrottle) handler(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		release, acquireErr := t.acquire(ctx)
		if acquireErr != nil {
			// Only the throttle's own sentinel errors are logged; any other
			// error (e.g. from the request context) is rejected silently.
			if errors.Is(acquireErr, ErrThrottleCapacityExceeded) {
				log.Warn(ctx, "Request throttle capacity exceeded", "path", r.URL.Path)
			} else if errors.Is(acquireErr, ErrThrottleTimeout) {
				log.Warn(ctx, "Request throttle backlog timeout", "path", r.URL.Path)
			}
			const status = http.StatusTooManyRequests
			http.Error(w, http.StatusText(status), status)
			return
		}

		captured := &bufferedResponseWriter{header: make(http.Header)}

		// Run the wrapped handler in its own function so the deferred release
		// fires before anything is sent to the client.
		runWithSlot := func() {
			defer release()
			next.ServeHTTP(captured, r)
		}
		runWithSlot()

		// Replay the captured response onto the real writer.
		maps.Copy(w.Header(), captured.header)
		if status := captured.code; status > 0 {
			w.WriteHeader(status)
		}
		_, writeErr := w.Write(captured.body.Bytes())
		if writeErr != nil {
			log.Warn(ctx, "Error writing throttled response", writeErr)
		}
	}
	return http.HandlerFunc(serve)
}
// acquire obtains a processing slot for one request.
//
// It first takes a backlog token without blocking; if none is available it
// returns ErrThrottleCapacityExceeded (or ctx.Err() if ctx is already done).
// It then takes a processing token, waiting up to t.backlogTimeout when none
// is immediately free. If the wait ends because the timer fired or ctx was
// done, the backlog token is handed back and ErrThrottleTimeout or ctx.Err()
// is returned.
//
// On success the returned release function must be called to give both tokens
// back; release is nil whenever err is non-nil.
func (t *requestThrottle) acquire(ctx context.Context) (release func(), err error) {
	// Admission: non-blocking grab of a backlog token.
	select {
	case <-t.backlogTokens:
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
		return nil, ErrThrottleCapacityExceeded
	}

	// Fast path: a processing token is free right now.
	select {
	case <-t.tokens:
		return t.releaseFunc(), nil
	default:
	}

	// Slow path: wait for a processing token, the timeout, or cancellation.
	wait := time.NewTimer(t.backlogTimeout)
	defer wait.Stop()

	select {
	case <-t.tokens:
		return t.releaseFunc(), nil
	case <-ctx.Done():
		err = ctx.Err()
	case <-wait.C:
		err = ErrThrottleTimeout
	}

	// The request is leaving the queue without running: free its backlog slot.
	t.backlogTokens <- struct{}{}
	return nil, err
}
// releaseFunc returns a function that gives one processing token and one
// backlog token back to the throttle. The returned function may be called
// any number of times; the tokens are returned only on the first call.
func (t *requestThrottle) releaseFunc() func() {
	guard := new(sync.Once)
	giveBack := func() {
		t.tokens <- struct{}{}
		t.backlogTokens <- struct{}{}
	}
	return func() { guard.Do(giveBack) }
}
// bufferedResponseWriter is an in-memory http.ResponseWriter. It records the
// headers, status code and body produced by a handler so they can be replayed
// onto a real ResponseWriter later.
type bufferedResponseWriter struct {
	header http.Header  // headers set by the handler
	body   bytes.Buffer // accumulated response body
	code   int          // status code; 0 until WriteHeader or Write is called
}

// Header returns the header map the handler should populate.
func (w *bufferedResponseWriter) Header() http.Header {
	return w.header
}

// Write appends b to the buffered body.
//
// As the http.ResponseWriter contract requires, a Write that is not preceded
// by WriteHeader implies WriteHeader(http.StatusOK). Recording that here
// ensures a later WriteHeader call cannot change the status after body bytes
// were produced, matching what a real ResponseWriter would have sent.
func (w *bufferedResponseWriter) Write(b []byte) (int, error) {
	w.WriteHeader(http.StatusOK) // no-op when a status was already recorded
	return w.body.Write(b)
}

// WriteHeader records the status code. Only the first call has any effect;
// subsequent calls are ignored.
func (w *bufferedResponseWriter) WriteHeader(code int) {
	if w.code != 0 {
		return
	}
	w.code = code
}