mirror of
https://github.com/tailscale/tailscale.git
synced 2026-04-03 06:02:30 -04:00
fixup! derp/derpserver: add per-connection receive rate limiting
This commit is contained in:
@@ -1211,12 +1211,10 @@ func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error {
|
||||
return fmt.Errorf("client %v: recvPacket: %v", c.key, err)
|
||||
}
|
||||
|
||||
// Rate limit non-DISCO packets via TCP backpressure. By blocking
|
||||
// here, we delay reading the next frame, causing the TCP receive
|
||||
// buffer to fill and the TCP window to shrink, which throttles the
|
||||
// sender. DISCO frames are exempt because they are small control
|
||||
// messages critical for direct connection establishment.
|
||||
if c.recvLim != nil && !disco.LooksLikeDiscoWrapper(contents) {
|
||||
// Rate limit via TCP backpressure. By blocking here, we delay
|
||||
// reading the next frame, causing the TCP receive buffer to fill
|
||||
// and the TCP window to shrink, which throttles the sender.
|
||||
if c.recvLim != nil {
|
||||
if err := c.recvLim.WaitN(c.ctx, len(contents)); err != nil {
|
||||
return nil // context canceled, connection closing
|
||||
}
|
||||
@@ -1540,12 +1538,7 @@ func (s *Server) noteClientActivity(c *sclient) {
|
||||
type ServerInfo = derp.ServerInfo
|
||||
|
||||
func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) error {
|
||||
si := ServerInfo{Version: derp.ProtocolVersion}
|
||||
if s.perClientRecvBytesPerSec > 0 {
|
||||
si.TokenBucketBytesPerSecond = s.perClientRecvBytesPerSec
|
||||
si.TokenBucketBytesBurst = s.perClientRecvBurst
|
||||
}
|
||||
msg, err := json.Marshal(si)
|
||||
msg, err := json.Marshal(ServerInfo{Version: derp.ProtocolVersion})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -954,11 +954,9 @@ func BenchmarkHyperLogLogEstimate(b *testing.B) {
|
||||
}
|
||||
|
||||
func TestPerClientRateLimit(t *testing.T) {
|
||||
// newServer creates a DERP server with a listener and returns a client factory.
|
||||
newServer := func(t *testing.T, bytesPerSec, burst int) (*Server, func(t *testing.T) *derp.Client) {
|
||||
newServer := func(t *testing.T, bytesPerSec, burst int) func(t *testing.T) *derp.Client {
|
||||
t.Helper()
|
||||
serverPriv := key.NewNode()
|
||||
s := New(serverPriv, logger.Discard)
|
||||
s := New(key.NewNode(), logger.Discard)
|
||||
if bytesPerSec > 0 {
|
||||
s.SetPerClientRateLimit(bytesPerSec, burst)
|
||||
}
|
||||
@@ -984,35 +982,47 @@ func TestPerClientRateLimit(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
newClient := func(t *testing.T) *derp.Client {
|
||||
return func(t *testing.T) *derp.Client {
|
||||
t.Helper()
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { conn.Close() })
|
||||
k := key.NewNode()
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
c, err := derp.NewClient(k, conn, brw, logger.Discard)
|
||||
c, err := derp.NewClient(key.NewNode(), conn, brw, logger.Discard)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient: %v", err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
return s, newClient
|
||||
}
|
||||
|
||||
// recvNPackets receives exactly n ReceivedPacket messages from c,
|
||||
// discarding any other message types (e.g. ServerInfoMessage).
|
||||
// It returns the time taken to receive all n data packets.
|
||||
recvNPackets := func(t *testing.T, c *derp.Client, n int) time.Duration {
|
||||
// sendRecv sends numPkts packets of pktSize bytes and returns how long
|
||||
// it takes to receive them all.
|
||||
sendRecv := func(t *testing.T, newClient func(*testing.T) *derp.Client, pktSize, numPkts int) time.Duration {
|
||||
t.Helper()
|
||||
sender := newClient(t)
|
||||
receiver := newClient(t)
|
||||
// Drain ServerInfoMessage.
|
||||
if _, err := receiver.Recv(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg := make([]byte, pktSize)
|
||||
go func() {
|
||||
for i := range numPkts {
|
||||
if err := sender.Send(receiver.PublicKey(), msg); err != nil {
|
||||
t.Errorf("Send(%d): %v", i, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
start := time.Now()
|
||||
got := 0
|
||||
for got < n {
|
||||
m, err := c.Recv()
|
||||
for got < numPkts {
|
||||
m, err := receiver.Recv()
|
||||
if err != nil {
|
||||
t.Fatalf("Recv: %v (got %d/%d)", err, got, n)
|
||||
t.Fatalf("Recv: %v (got %d/%d)", err, got, numPkts)
|
||||
}
|
||||
if _, ok := m.(derp.ReceivedPacket); ok {
|
||||
got++
|
||||
@@ -1021,83 +1031,29 @@ func TestPerClientRateLimit(t *testing.T) {
|
||||
return time.Since(start)
|
||||
}
|
||||
|
||||
t.Run("non_disco_throttled", func(t *testing.T) {
|
||||
// Use a rate that will show measurable delay.
|
||||
// SetPerClientRateLimit clamps burst to max(burst, MaxPacketSize=64KB).
|
||||
// So with 100KB/s rate and 64KB effective burst, sending 128KB of data
|
||||
// should take at least ~640ms for the 64KB over burst.
|
||||
const bytesPerSec = 100_000
|
||||
_, newClient := newServer(t, bytesPerSec, bytesPerSec)
|
||||
sender := newClient(t)
|
||||
receiver := newClient(t)
|
||||
|
||||
// Drain the ServerInfoMessage from receiver before timing.
|
||||
if _, err := receiver.Recv(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("throttled", func(t *testing.T) {
|
||||
// Compare transfer time with and without rate limiting.
|
||||
// This avoids flaky absolute time thresholds.
|
||||
const pktSize = 1000
|
||||
const numPkts = 128 // 128KB total
|
||||
msg := make([]byte, pktSize)
|
||||
const numPkts = 128 // 128KB total, exceeds the 64KB effective burst
|
||||
|
||||
// Send all packets.
|
||||
for i := range numPkts {
|
||||
if err := sender.Send(receiver.PublicKey(), msg); err != nil {
|
||||
t.Fatalf("Send(%d): %v", i, err)
|
||||
}
|
||||
unlimited := sendRecv(t, newServer(t, 0, 0), pktSize, numPkts)
|
||||
limited := sendRecv(t, newServer(t, 100_000, 100_000), pktSize, numPkts)
|
||||
|
||||
t.Logf("unlimited=%v, limited=%v (ratio=%.0fx)", unlimited, limited, float64(limited)/float64(unlimited))
|
||||
|
||||
// Rate-limited transfer should take at least 10x longer.
|
||||
// In practice: ~280ms vs ~200µs (~1400x).
|
||||
if limited < 10*unlimited {
|
||||
t.Errorf("rate-limited transfer not slower enough: unlimited=%v, limited=%v", unlimited, limited)
|
||||
}
|
||||
|
||||
// Measure how long it takes to receive all data packets.
|
||||
elapsed := recvNPackets(t, receiver, numPkts)
|
||||
|
||||
// 128KB total, ~64KB effective burst, 100KB/s rate.
|
||||
// Should take meaningfully longer than without rate limiting.
|
||||
// Without rate limiting, the same data transfers in <1ms on loopback.
|
||||
if elapsed < 100*time.Millisecond {
|
||||
t.Errorf("expected receives to be throttled, but took only %v", elapsed)
|
||||
}
|
||||
t.Logf("received %d packets of %d bytes in %v", numPkts, pktSize, elapsed)
|
||||
})
|
||||
|
||||
t.Run("disco_not_throttled", func(t *testing.T) {
|
||||
// Same rate as above, but DISCO packets should bypass the limiter.
|
||||
// Send the same amount of data to contrast with the throttled case.
|
||||
const bytesPerSec = 100_000
|
||||
_, newClient := newServer(t, bytesPerSec, bytesPerSec)
|
||||
sender := newClient(t)
|
||||
receiver := newClient(t)
|
||||
|
||||
if _, err := receiver.Recv(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// disco.Magic (6 bytes) + 32 byte key + 24 byte nonce + payload
|
||||
discoPacket := make([]byte, 6+32+24+932) // ~1000 bytes total
|
||||
copy(discoPacket, "TS💬") // disco.Magic
|
||||
|
||||
const numPkts = 128
|
||||
for i := range numPkts {
|
||||
if err := sender.Send(receiver.PublicKey(), discoPacket); err != nil {
|
||||
t.Fatalf("Send(%d): %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := recvNPackets(t, receiver, numPkts)
|
||||
|
||||
// DISCO packets bypass the rate limiter; should complete quickly
|
||||
// (no 640ms+ delay like the non-DISCO case).
|
||||
if elapsed > 2*time.Second {
|
||||
t.Errorf("expected DISCO receives to be fast, but took %v", elapsed)
|
||||
}
|
||||
t.Logf("received %d DISCO packets in %v", numPkts, elapsed)
|
||||
})
|
||||
|
||||
t.Run("mesh_peer_exempt", func(t *testing.T) {
|
||||
// Verify the server would not assign a rate limiter to mesh peers.
|
||||
s, _ := newServer(t, 10_000, 10_000)
|
||||
s := New(key.NewNode(), logger.Discard)
|
||||
s.SetPerClientRateLimit(10_000, 10_000)
|
||||
defer s.Close()
|
||||
c := &sclient{s: s, canMesh: true}
|
||||
// accept() logic: s.perClientRecvBytesPerSec > 0 && !c.canMesh
|
||||
// For mesh peer (canMesh=true), condition is false → no limiter.
|
||||
if s.perClientRecvBytesPerSec > 0 && !c.canMesh {
|
||||
t.Error("mesh peer should be exempt from rate limiting")
|
||||
}
|
||||
@@ -1107,7 +1063,8 @@ func TestPerClientRateLimit(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("zero_config_no_limiter", func(t *testing.T) {
|
||||
s, _ := newServer(t, 0, 0)
|
||||
s := New(key.NewNode(), logger.Discard)
|
||||
defer s.Close()
|
||||
if s.perClientRecvBytesPerSec != 0 {
|
||||
t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user