mirror of
https://github.com/navidrome/navidrome.git
synced 2026-06-11 00:56:16 -04:00
150 lines
3.6 KiB
Go
150 lines
3.6 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"maps"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/navidrome/navidrome/conf"
|
|
"github.com/navidrome/navidrome/log"
|
|
)
|
|
|
|
// Sentinel errors produced by the buffered request throttle. Callers should
// match them with errors.Is.
var (
	// ErrThrottleCapacityExceeded reports that every processing and backlog
	// slot was already taken, so the request was rejected without waiting.
	ErrThrottleCapacityExceeded = errors.New("throttle: capacity exceeded")

	// ErrThrottleTimeout reports that the request waited in the backlog for
	// the configured timeout without obtaining a processing slot.
	ErrThrottleTimeout = errors.New("throttle: backlog timeout")
)
|
|
|
|
// requestThrottle holds the state of the buffered ThrottleBacklog middleware.
//
// Both channels act as counting semaphores: every buffered struct{} value is
// one free slot, receiving from a channel takes a slot and sending returns it.
type requestThrottle struct {
	// tokens has one value per request allowed to run the handler at once;
	// backlogTokens has one value per request allowed to be admitted at all
	// (running or waiting for a token).
	tokens, backlogTokens chan struct{}

	// backlogTimeout bounds how long an admitted request waits for a token.
	backlogTimeout time.Duration
}
|
|
|
|
// ThrottleBacklog creates a Chi-compatible middleware that limits concurrent
|
|
// request processing. Unlike Chi's ThrottleBacklog, it buffers the handler's
|
|
// response while holding the token, releases it, then flushes the buffer to
|
|
// the client with a write deadline. This prevents slow clients from holding
|
|
// throttle capacity.
|
|
//
|
|
// Because it buffers the entire response in memory, this middleware should only
|
|
// be used for endpoints that return small responses (e.g., artwork images). Do
|
|
// not use it for audio streaming or download endpoints.
|
|
func ThrottleBacklog(limit, backlogLimit int, backlogTimeout time.Duration) func(http.Handler) http.Handler {
|
|
if limit <= 0 {
|
|
return func(next http.Handler) http.Handler { return next }
|
|
}
|
|
if !conf.Server.DevArtworkThrottleBuffered {
|
|
return middleware.ThrottleBacklog(limit, backlogLimit, backlogTimeout)
|
|
}
|
|
t := &requestThrottle{
|
|
tokens: make(chan struct{}, limit),
|
|
backlogTokens: make(chan struct{}, limit+backlogLimit),
|
|
backlogTimeout: backlogTimeout,
|
|
}
|
|
for range limit {
|
|
t.tokens <- struct{}{}
|
|
}
|
|
for range limit + backlogLimit {
|
|
t.backlogTokens <- struct{}{}
|
|
}
|
|
return t.handler
|
|
}
|
|
|
|
func (t *requestThrottle) handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
release, err := t.acquire(ctx)
|
|
if err != nil {
|
|
switch {
|
|
case errors.Is(err, ErrThrottleCapacityExceeded):
|
|
log.Warn(ctx, "Request throttle capacity exceeded", "path", r.URL.Path)
|
|
case errors.Is(err, ErrThrottleTimeout):
|
|
log.Warn(ctx, "Request throttle backlog timeout", "path", r.URL.Path)
|
|
}
|
|
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
buf := &bufferedResponseWriter{header: make(http.Header)}
|
|
func() {
|
|
defer release()
|
|
next.ServeHTTP(buf, r)
|
|
}()
|
|
|
|
maps.Copy(w.Header(), buf.header)
|
|
if buf.code > 0 {
|
|
w.WriteHeader(buf.code)
|
|
}
|
|
if _, err := w.Write(buf.body.Bytes()); err != nil {
|
|
log.Warn(ctx, "Error writing throttled response", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (t *requestThrottle) acquire(ctx context.Context) (release func(), err error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-t.backlogTokens:
|
|
default:
|
|
return nil, ErrThrottleCapacityExceeded
|
|
}
|
|
|
|
select {
|
|
case <-t.tokens:
|
|
return t.releaseFunc(), nil
|
|
default:
|
|
}
|
|
|
|
timer := time.NewTimer(t.backlogTimeout)
|
|
select {
|
|
case <-timer.C:
|
|
t.backlogTokens <- struct{}{}
|
|
return nil, ErrThrottleTimeout
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
t.backlogTokens <- struct{}{}
|
|
return nil, ctx.Err()
|
|
case <-t.tokens:
|
|
timer.Stop()
|
|
return t.releaseFunc(), nil
|
|
}
|
|
}
|
|
|
|
func (t *requestThrottle) releaseFunc() func() {
|
|
var once sync.Once
|
|
return func() {
|
|
once.Do(func() {
|
|
t.tokens <- struct{}{}
|
|
t.backlogTokens <- struct{}{}
|
|
})
|
|
}
|
|
}
|
|
|
|
// bufferedResponseWriter is an in-memory http.ResponseWriter. It records the
// headers, status code and body written by a handler so they can be replayed
// onto the real writer later.
//
// Only the core ResponseWriter methods are implemented. Optional interfaces
// such as http.Flusher or http.Hijacker are not supported.
type bufferedResponseWriter struct {
	header http.Header  // headers set by the handler
	body   bytes.Buffer // everything passed to Write
	code   int          // final status; 0 means neither WriteHeader nor Write was called
}

// Header returns the header map the handler should modify.
func (w *bufferedResponseWriter) Header() http.Header {
	return w.header
}

// Write appends b to the buffered body. As with net/http, a Write before any
// WriteHeader commits the status to 200 OK. A later WriteHeader therefore
// cannot change the status the real writer would have sent.
func (w *bufferedResponseWriter) Write(b []byte) (int, error) {
	if w.code == 0 {
		w.code = http.StatusOK
	}
	return w.body.Write(b)
}

// WriteHeader records the first final status code; later calls are ignored.
//
// Informational 1xx codes other than 101 are dropped, because net/http does
// not treat them as final. If one were latched here, the replay would send it
// as an interim response and then default to 200, losing the handler's real
// status.
func (w *bufferedResponseWriter) WriteHeader(code int) {
	if w.code != 0 {
		return
	}
	if code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols {
		return
	}
	w.code = code
}
|