Files
tailscale/derp/client_test.go
Mike O'Driscoll f52c1e3615 derp: use AvailableBuffer for WriteFrameHeader, consolidate tests (#19101)
Use bufio.Writer.AvailableBuffer to write the frame header directly
into bufio's internal buffer as a single append+Write, avoiding 5
separate WriteByte calls. Fall back to the existing writeUint32
byte-at-a-time path when the buffer has insufficient space.

```
name                  old ns/op  new ns/op  speedup
WriteFrameHeader-8    18.8       7.8        ~2.4x
(0 allocs/op in both)
```

Add TestWriteFrameHeader with correctness
checks, allocation assertions, and coverage of both fast and slow
write paths. Move BenchmarkReadFrameHeader from client_test.go to
derp_test.go alongside BenchmarkWriteFrameHeader, co-located with
the functions under test.

Updates tailscale/corp#38509

Signed-off-by: Mike O'Driscoll <mikeo@tailscale.com>
2026-03-24 18:08:01 -04:00

205 lines
4.2 KiB
Go

// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package derp
import (
"bufio"
"bytes"
"net"
"reflect"
"sync"
"testing"
"time"
"tailscale.com/tstest"
"tailscale.com/types/key"
)
// dummyNetConn is a placeholder net.Conn used by the tests in this file.
// It embeds a nil net.Conn, so calling any method other than
// SetReadDeadline panics with a nil dereference. SetReadDeadline is
// overridden to be a no-op that always succeeds.
type dummyNetConn struct {
	net.Conn
}

// SetReadDeadline discards the requested deadline and reports success.
func (c dummyNetConn) SetReadDeadline(_ time.Time) error {
	return nil
}
// TestClientRecv feeds hand-encoded frames to Client.Recv and checks that
// each one decodes to the expected message value.
//
// Each frame is laid out as a one-byte frame type, a four-byte big-endian
// payload length, and then the payload itself. The restarting frame's
// payload is two big-endian uint32 values that Recv returns as millisecond
// durations.
func TestClientRecv(t *testing.T) {
	type recvCase struct {
		name  string
		frame []byte // raw bytes as read from the connection
		want  any    // message value Recv should return
	}
	cases := []recvCase{
		{
			name: "ping",
			frame: []byte{
				byte(FramePing), 0, 0, 0, 8,
				1, 2, 3, 4, 5, 6, 7, 8,
			},
			want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8},
		},
		{
			name: "pong",
			frame: []byte{
				byte(FramePong), 0, 0, 0, 8,
				1, 2, 3, 4, 5, 6, 7, 8,
			},
			want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8},
		},
		{
			name: "health_bad",
			frame: []byte{
				byte(FrameHealth), 0, 0, 0, 3,
				byte('B'), byte('A'), byte('D'),
			},
			want: HealthMessage{Problem: "BAD"},
		},
		{
			// A zero-length health payload means "no problem".
			name: "health_ok",
			frame: []byte{
				byte(FrameHealth), 0, 0, 0, 0,
			},
			want: HealthMessage{},
		},
		{
			name: "server_restarting",
			frame: []byte{
				byte(FrameRestarting), 0, 0, 0, 8,
				0, 0, 0, 1,
				0, 0, 0, 2,
			},
			want: ServerRestartingMessage{
				ReconnectIn: 1 * time.Millisecond,
				TryFor:      2 * time.Millisecond,
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// nc is a dummyNetConn, whose only working method is
			// SetReadDeadline; all frame bytes come from br.
			reader := bufio.NewReader(bytes.NewReader(tc.frame))
			c := &Client{
				nc:    dummyNetConn{},
				br:    reader,
				logf:  t.Logf,
				clock: &tstest.Clock{},
			}
			msg, err := c.Recv()
			if err != nil {
				t.Fatal(err)
			}
			if !reflect.DeepEqual(msg, tc.want) {
				t.Errorf("got %#v; want %#v", msg, tc.want)
			}
		})
	}
}
// TestClientSendPing checks that SendPing writes a FramePing header (the
// type byte plus a four-byte length of 8) followed by the 8-byte payload.
func TestClientSendPing(t *testing.T) {
	out := new(bytes.Buffer)
	c := &Client{bw: bufio.NewWriter(out)}

	payload := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
	if err := c.SendPing(payload); err != nil {
		t.Fatal(err)
	}

	want := append([]byte{byte(FramePing), 0, 0, 0, 8}, payload[:]...)
	if got := out.Bytes(); !bytes.Equal(got, want) {
		t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", got, want)
	}
}
// TestClientSendPong checks that SendPong writes a FramePong header (the
// type byte plus a four-byte length of 8) followed by the 8-byte payload.
func TestClientSendPong(t *testing.T) {
	out := new(bytes.Buffer)
	c := &Client{bw: bufio.NewWriter(out)}

	payload := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
	if err := c.SendPong(payload); err != nil {
		t.Fatal(err)
	}

	header := []byte{byte(FramePong), 0, 0, 0, 8}
	want := append(header, payload[:]...)
	if got := out.Bytes(); !bytes.Equal(got, want) {
		t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", got, want)
	}
}
// countWriter is a writer that discards its input. It tallies how many
// Write calls it has received and how many bytes those calls carried.
// Every method takes mu, so it may be shared between goroutines.
type countWriter struct {
	mu     sync.Mutex
	writes int   // number of Write calls since the last reset
	bytes  int64 // total len(p) across those calls
}

// Write records one call of len(p) bytes and always reports full success.
func (w *countWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.mu.Lock()
	w.writes++
	w.bytes += int64(n)
	w.mu.Unlock()
	return n, nil
}

// Stats returns the current call count and byte total as a single
// consistent snapshot.
func (w *countWriter) Stats() (writes int, bytes int64) {
	w.mu.Lock()
	writes, bytes = w.writes, w.bytes
	w.mu.Unlock()
	return writes, bytes
}

// ResetStats zeroes both counters.
func (w *countWriter) ResetStats() {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.writes = 0
	w.bytes = 0
}
// TestClientSendRateLimiting checks how many writes and bytes Client.send
// pushes to the underlying writer, with and without a send rate limiter.
//
// The test runs in three phases:
//  1. One send of a 1000-byte packet with an empty ServerInfoMessage limiter
//     config. This establishes the per-send byte cost (bytes1).
//  2. A flood of 1000 such sends with the same config.
//  3. A flood of 1000 sends after configuring TokenBucketBytesPerSecond=1
//     and a burst of 2*bytes1. This flood must write some, but strictly
//     fewer, writes and bytes than the unlimited flood.
func TestClientSendRateLimiting(t *testing.T) {
	cw := new(countWriter)
	c := &Client{
		bw:    bufio.NewWriter(cw),
		clock: &tstest.Clock{},
	}
	c.setSendRateLimiter(ServerInfoMessage{})

	pkt := make([]byte, 1000)
	if err := c.send(key.NodePublic{}, pkt); err != nil {
		t.Fatal(err)
	}
	// A single send must reach the underlying writer as exactly one Write.
	writes1, bytes1 := cw.Stats()
	if writes1 != 1 {
		t.Errorf("writes = %v, want 1", writes1)
	}

	// Flood should all succeed.
	cw.ResetStats()
	for range 1000 {
		if err := c.send(key.NodePublic{}, pkt); err != nil {
			t.Fatal(err)
		}
	}
	writes1K, bytes1K := cw.Stats()
	// These two mismatches are only logged, not reported as test failures.
	if writes1K != 1000 {
		t.Logf("writes = %v; want 1000", writes1K)
	}
	if got, want := bytes1K, bytes1*1000; got != want {
		t.Logf("bytes = %v; want %v", got, want)
	}

	// Set a rate limiter whose burst covers exactly two single sends.
	cw.ResetStats()
	c.setSendRateLimiter(ServerInfoMessage{
		TokenBucketBytesPerSecond: 1,
		TokenBucketBytesBurst:     int(bytes1 * 2),
	})
	// send must keep returning nil while the limiter is active.
	for range 1000 {
		if err := c.send(key.NodePublic{}, pkt); err != nil {
			t.Fatal(err)
		}
	}
	writesLimited, bytesLimited := cw.Stats()
	// The limited flood must write something, but strictly less than the
	// unlimited flood did. Using >= also rejects a count above the baseline.
	if writesLimited == 0 || writesLimited >= writes1K {
		t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited)
	}
	// The lower bound is the burst size (two single sends), not 2*bytes1K.
	if bytesLimited < bytes1*2 || bytesLimited >= bytes1K {
		t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1*2, bytes1K)
	}
}