From e739b5e10463d35da8e1d04afe6fcf5fad06e6d2 Mon Sep 17 00:00:00 2001 From: Mike O'Driscoll Date: Thu, 2 Apr 2026 19:32:23 +0000 Subject: [PATCH] fixup! derp/derpserver: add per-connection receive rate limiting --- derp/derpserver/derpserver.go | 17 ++-- derp/derpserver/derpserver_test.go | 129 ++++++++++------------------- 2 files changed, 48 insertions(+), 98 deletions(-) diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index eab47f8e9..50e2ec884 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -1211,12 +1211,10 @@ func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error { return fmt.Errorf("client %v: recvPacket: %v", c.key, err) } - // Rate limit non-DISCO packets via TCP backpressure. By blocking - // here, we delay reading the next frame, causing the TCP receive - // buffer to fill and the TCP window to shrink, which throttles the - // sender. DISCO frames are exempt because they are small control - // messages critical for direct connection establishment. - if c.recvLim != nil && !disco.LooksLikeDiscoWrapper(contents) { + // Rate limit via TCP backpressure. By blocking here, we delay + // reading the next frame, causing the TCP receive buffer to fill + // and the TCP window to shrink, which throttles the sender. + if c.recvLim != nil { if err := c.recvLim.WaitN(c.ctx, len(contents)); err != nil { return nil // context canceled, connection closing } @@ -1540,12 +1538,7 @@ func (s *Server) noteClientActivity(c *sclient) { type ServerInfo = derp.ServerInfo func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) error { - si := ServerInfo{Version: derp.ProtocolVersion} - if s.perClientRecvBytesPerSec > 0 { - si.TokenBucketBytesPerSecond = s.perClientRecvBytesPerSec - si.TokenBucketBytesBurst = s.perClientRecvBurst - } - msg, err := json.Marshal(si) + msg, err := json.Marshal(ServerInfo{Version: derp.ProtocolVersion}) if err != nil { return err } diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index 75a135b0e..93d9c0c8b 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -954,11 +954,9 @@ func BenchmarkHyperLogLogEstimate(b *testing.B) { } func TestPerClientRateLimit(t *testing.T) { - // newServer creates a DERP server with a listener and returns a client factory. - newServer := func(t *testing.T, bytesPerSec, burst int) (*Server, func(t *testing.T) *derp.Client) { + newServer := func(t *testing.T, bytesPerSec, burst int) func(t *testing.T) *derp.Client { t.Helper() - serverPriv := key.NewNode() - s := New(serverPriv, logger.Discard) + s := New(key.NewNode(), logger.Discard) if bytesPerSec > 0 { s.SetPerClientRateLimit(bytesPerSec, burst) } @@ -984,35 +982,47 @@ func TestPerClientRateLimit(t *testing.T) { } }() - newClient := func(t *testing.T) *derp.Client { + return func(t *testing.T) *derp.Client { t.Helper() conn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatal(err) } t.Cleanup(func() { conn.Close() }) - k := key.NewNode() brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - c, err := derp.NewClient(k, conn, brw, logger.Discard) + c, err := derp.NewClient(key.NewNode(), conn, brw, logger.Discard) if err != nil { t.Fatalf("NewClient: %v", err) } return c } - return s, newClient } - // recvNPackets receives exactly n ReceivedPacket messages from c, - // discarding any other message types (e.g. ServerInfoMessage). - // It returns the time taken to receive all n data packets. - recvNPackets := func(t *testing.T, c *derp.Client, n int) time.Duration { + // sendRecv sends numPkts packets of pktSize bytes and returns how long + // it takes to receive them all. + sendRecv := func(t *testing.T, newClient func(*testing.T) *derp.Client, pktSize, numPkts int) time.Duration { t.Helper() + sender := newClient(t) + receiver := newClient(t) + // Drain ServerInfoMessage. + if _, err := receiver.Recv(); err != nil { + t.Fatal(err) + } + msg := make([]byte, pktSize) + go func() { + for i := range numPkts { + if err := sender.Send(receiver.PublicKey(), msg); err != nil { + t.Errorf("Send(%d): %v", i, err) + return + } + } + }() start := time.Now() got := 0 - for got < n { - m, err := c.Recv() + for got < numPkts { + m, err := receiver.Recv() if err != nil { - t.Fatalf("Recv: %v (got %d/%d)", err, got, n) + t.Fatalf("Recv: %v (got %d/%d)", err, got, numPkts) } if _, ok := m.(derp.ReceivedPacket); ok { got++ @@ -1021,83 +1031,29 @@ func TestPerClientRateLimit(t *testing.T) { return time.Since(start) } - t.Run("non_disco_throttled", func(t *testing.T) { - // Use a rate that will show measurable delay. - // SetPerClientRateLimit clamps burst to max(burst, MaxPacketSize=64KB). - // So with 100KB/s rate and 64KB effective burst, sending 128KB of data - // should take at least ~640ms for the 64KB over burst. - const bytesPerSec = 100_000 - _, newClient := newServer(t, bytesPerSec, bytesPerSec) - sender := newClient(t) - receiver := newClient(t) - - // Drain the ServerInfoMessage from receiver before timing. - if _, err := receiver.Recv(); err != nil { - t.Fatal(err) - } - + t.Run("throttled", func(t *testing.T) { + // Compare transfer time with and without rate limiting. + // This avoids flaky absolute time thresholds. const pktSize = 1000 - const numPkts = 128 // 128KB total - msg := make([]byte, pktSize) + const numPkts = 128 // 128KB total, exceeds the 64KB effective burst - // Send all packets. - for i := range numPkts { - if err := sender.Send(receiver.PublicKey(), msg); err != nil { - t.Fatalf("Send(%d): %v", i, err) - } + unlimited := sendRecv(t, newServer(t, 0, 0), pktSize, numPkts) + limited := sendRecv(t, newServer(t, 100_000, 100_000), pktSize, numPkts) + + t.Logf("unlimited=%v, limited=%v (ratio=%.0fx)", unlimited, limited, float64(limited)/float64(unlimited)) + + // Rate-limited transfer should take at least 10x longer. + // In practice: ~280ms vs ~200µs (~1400x). + if limited < 10*unlimited { + t.Errorf("rate-limited transfer not slower enough: unlimited=%v, limited=%v", unlimited, limited) } - - // Measure how long it takes to receive all data packets. - elapsed := recvNPackets(t, receiver, numPkts) - - // 128KB total, ~64KB effective burst, 100KB/s rate. - // Should take meaningfully longer than without rate limiting. - // Without rate limiting, the same data transfers in <1ms on loopback. - if elapsed < 100*time.Millisecond { - t.Errorf("expected receives to be throttled, but took only %v", elapsed) - } - t.Logf("received %d packets of %d bytes in %v", numPkts, pktSize, elapsed) - }) - - t.Run("disco_not_throttled", func(t *testing.T) { - // Same rate as above, but DISCO packets should bypass the limiter. - // Send the same amount of data to contrast with the throttled case. - const bytesPerSec = 100_000 - _, newClient := newServer(t, bytesPerSec, bytesPerSec) - sender := newClient(t) - receiver := newClient(t) - - if _, err := receiver.Recv(); err != nil { - t.Fatal(err) - } - - // disco.Magic (6 bytes) + 32 byte key + 24 byte nonce + payload - discoPacket := make([]byte, 6+32+24+932) // ~1000 bytes total - copy(discoPacket, "TS💬") // disco.Magic - - const numPkts = 128 - for i := range numPkts { - if err := sender.Send(receiver.PublicKey(), discoPacket); err != nil { - t.Fatalf("Send(%d): %v", i, err) - } - } - - elapsed := recvNPackets(t, receiver, numPkts) - - // DISCO packets bypass the rate limiter; should complete quickly - // (no 640ms+ delay like the non-DISCO case). - if elapsed > 2*time.Second { - t.Errorf("expected DISCO receives to be fast, but took %v", elapsed) - } - t.Logf("received %d DISCO packets in %v", numPkts, elapsed) }) t.Run("mesh_peer_exempt", func(t *testing.T) { - // Verify the server would not assign a rate limiter to mesh peers. - s, _ := newServer(t, 10_000, 10_000) + s := New(key.NewNode(), logger.Discard) + s.SetPerClientRateLimit(10_000, 10_000) + defer s.Close() c := &sclient{s: s, canMesh: true} - // accept() logic: s.perClientRecvBytesPerSec > 0 && !c.canMesh - // For mesh peer (canMesh=true), condition is false → no limiter. if s.perClientRecvBytesPerSec > 0 && !c.canMesh { t.Error("mesh peer should be exempt from rate limiting") } @@ -1107,7 +1063,8 @@ func TestPerClientRateLimit(t *testing.T) { }) t.Run("zero_config_no_limiter", func(t *testing.T) { - s, _ := newServer(t, 0, 0) + s := New(key.NewNode(), logger.Discard) + defer s.Close() if s.perClientRecvBytesPerSec != 0 { t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec) }