net/batching: reset Buffers len in WriteBatchTo

Reset the length of each message's Buffers in the non-coalesce branch, in
case we land on this branch during a goto retry. Also, protect the Geneve
offset from mutation across retries.

Fixes #19927

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2026-05-29 15:24:15 -07:00
committed by Jordan Whited
parent 3e34e721e8
commit 8a294e3c34
2 changed files with 163 additions and 2 deletions

View File

@@ -296,15 +296,20 @@ func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, ge
if c.txOffload.Load() && (!neverGSOEqualTail || len(buffs) >= appendSentinelTailBatchSizeThreshold) {
n = c.coalesceMessages(batch.ua, geneve, buffs, batch.msgs, offset, neverGSOEqualTail)
} else {
mutableOffset := offset // don't mutate offset across retries
vniIsSet := geneve.VNI.IsSet()
if vniIsSet {
offset -= packet.GeneveFixedHeaderLength
mutableOffset -= packet.GeneveFixedHeaderLength
}
for i := range buffs {
if vniIsSet {
geneve.Encode(buffs[i])
}
batch.msgs[i].Buffers[0] = buffs[i][offset:]
batch.msgs[i].Buffers[0] = buffs[i][mutableOffset:]
// Buffers length may be > 1 (scatter-gather) if we passed through
// coalesceMessages during a first pass, and landed here as part of
// goto retry.
batch.msgs[i].Buffers = batch.msgs[i].Buffers[:1]
batch.msgs[i].Addr = batch.ua
batch.msgs[i].OOB = batch.msgs[i].OOB[:0]
}

View File

@@ -5,9 +5,14 @@
import (
"encoding/binary"
"errors"
"io"
"math"
"net"
"net/netip"
"os"
"sync"
"sync/atomic"
"testing"
"unsafe"
@@ -15,6 +20,7 @@
"github.com/tailscale/wireguard-go/conn"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"tailscale.com/net/neterror"
"tailscale.com/net/packet"
)
@@ -423,6 +429,156 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
}
}
// fakeBatchWriter is an xnetBatchReaderWriter that records the Buffers length
// of each message handed to WriteBatch, and optionally fails the first call
// with an error that triggers neterror.ShouldDisableUDPGSO.
type fakeBatchWriter struct {
	// gotBuffersLen has one entry per WriteBatch call, in call order. Each
	// entry holds len(msg.Buffers) for every message passed to that call.
	gotBuffersLen [][]int
	// failFirst, when true, makes the first WriteBatch call return an error
	// instead of reporting the messages as written; later calls succeed.
	failFirst bool
}
// ReadBatch is a stub: it ignores its arguments and always reports zero
// messages read with a nil error.
func (f *fakeBatchWriter) ReadBatch(_ []ipv6.Message, _ int) (int, error) {
	return 0, nil
}
// WriteBatch records len(Buffers) for every message in msgs as one new entry
// of f.gotBuffersLen. If f.failFirst is set and this is the first recorded
// call, it returns 0 and a sendmmsg EIO syscall error; otherwise it reports
// all of msgs as written.
func (f *fakeBatchWriter) WriteBatch(msgs []ipv6.Message, _ int) (int, error) {
	lens := make([]int, 0, len(msgs))
	for i := range msgs {
		lens = append(lens, len(msgs[i].Buffers))
	}
	f.gotBuffersLen = append(f.gotBuffersLen, lens)

	// The snapshot is appended before this check, so a length of exactly one
	// identifies the first call.
	isFirstCall := len(f.gotBuffersLen) == 1
	if isFirstCall && f.failFirst {
		return 0, &os.SyscallError{Syscall: "sendmmsg", Err: unix.EIO}
	}
	return len(msgs), nil
}
// Test_linuxBatchingConn_WriteBatchTo_resetsBuffersOnGSORetry verifies that
// when a coalesced (scatter-gather) write fails and triggers the GSO-disable
// goto retry, the non-coalesce retry pass resets each message's Buffers back to
// length 1 rather than leaving stale iovecs appended by coalesceMessages.
func Test_linuxBatchingConn_WriteBatchTo_resetsBuffersOnGSORetry(t *testing.T) {
uc, err := net.ListenUDP("udp4", nil) // only for pc.LocalAddr() in the error path
if err != nil {
t.Fatal(err)
}
defer uc.Close()
xpc := &fakeBatchWriter{failFirst: true}
c := &linuxBatchingConn{
pc: uc,
xpc: xpc,
sendBatchPool: sync.Pool{New: func() any {
ua := &net.UDPAddr{IP: make([]byte, 16)}
msgs := make([]ipv6.Message, 8)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{ua: ua, msgs: msgs}
}},
}
c.txOffload.Store(true) // force the coalesce path on the first pass
// Two equal-length buffs coalesce into a single msg whose Buffers grows
// to len 2 (scatter-gather) on the first pass.
buffs := [][]byte{make([]byte, 32), make([]byte, 32)}
err = c.WriteBatchTo(buffs, netip.MustParseAddrPort("127.0.0.1:1"), packet.GeneveHeader{}, 0)
// The retry path always returns ErrUDPGSODisabled wrapping the retry's
// result (nil here).
if _, ok := errors.AsType[neterror.ErrUDPGSODisabled](err); !ok {
t.Fatalf("got %v, want ErrUDPGSODisabled", err)
}
if len(xpc.gotBuffersLen) != 2 {
t.Fatalf("got %d WriteBatch calls, want 2", len(xpc.gotBuffersLen))
}
// First (coalesced) call: one msg with 2 iovecs — confirms the precondition
// that coalesceMessages grew Buffers past length 1.
if got := xpc.gotBuffersLen[0]; len(got) != 1 || got[0] != 2 {
t.Fatalf("first call buffers = %v, want [2]", got)
}
// Retry (non-coalesce) call: sends one msg per buff...
if got := len(xpc.gotBuffersLen[1]); got != len(buffs) {
t.Fatalf("retry call sent %d msgs, want %d", got, len(buffs))
}
// ...and the fix must have reset every msg's Buffers back to len 1.
for i, n := range xpc.gotBuffersLen[1] {
if n != 1 {
t.Errorf("retry msg[%d] Buffers len = %d, want 1", i, n)
}
}
}
// Test_linuxBatchingConn_WriteBatchTo_offsetStableOnNonCoalesceRetry verifies
// that the Geneve header offset adjustment in the non-coalesce path is derived
// fresh from the original offset on each pass, rather than accumulating across a
// goto retry. The non-coalesce branch runs on both passes when neverGSOEqualTail
// is set and the batch is small enough to skip coalescing: the first pass fails
// with an error that disables GSO, and the retry re-enters the same branch.
// Since callers pass offset == GeneveFixedHeaderLength, a stale (accumulating)
// offset would underflow to -GeneveFixedHeaderLength and panic on buffs[i][-8:].
func Test_linuxBatchingConn_WriteBatchTo_offsetStableOnNonCoalesceRetry(t *testing.T) {
uc, err := net.ListenUDP("udp4", nil) // only for pc.LocalAddr() in the error path
if err != nil {
t.Fatal(err)
}
defer uc.Close()
xpc := &fakeBatchWriter{failFirst: true}
c := &linuxBatchingConn{
pc: uc,
xpc: xpc,
sendBatchPool: sync.Pool{New: func() any {
ua := &net.UDPAddr{IP: make([]byte, 16)}
msgs := make([]ipv6.Message, appendSentinelTailBatchSizeThreshold)
for i := range msgs {
msgs[i].Buffers = make([][]byte, 1)
msgs[i].Addr = ua
msgs[i].OOB = make([]byte, controlMessageSize)
}
return &sendBatch{ua: ua, msgs: msgs}
}},
}
c.txOffload.Store(true)
// neverGSOEqualTail set + a sub-threshold batch forces the non-coalesce
// path while txOffload is still enabled, so the GSO-disable retry re-enters
// the non-coalesce branch a second time.
var neverGSOEqualTail atomic.Bool
neverGSOEqualTail.Store(true)
c.neverGSOEqualTail = &neverGSOEqualTail
// VNI set so the non-coalesce branch performs the offset -= GeneveFixedHeaderLength
// adjustment; offset == GeneveFixedHeaderLength as the production caller requires.
geneve := packet.GeneveHeader{Protocol: packet.GeneveProtocolWireGuard}
geneve.VNI.Set(1)
offset := packet.GeneveFixedHeaderLength
// Stay below appendSentinelTailBatchSizeThreshold so coalescing is skipped
// and we take the non-coalesce branch on both passes.
const nBuffs = appendSentinelTailBatchSizeThreshold - 1
buffs := make([][]byte, nBuffs)
for i := range buffs {
buffs[i] = make([]byte, 32)
}
// Must not panic: each pass recomputes the offset from the original.
err = c.WriteBatchTo(buffs, netip.MustParseAddrPort("127.0.0.1:1"), geneve, offset)
if _, ok := errors.AsType[neterror.ErrUDPGSODisabled](err); !ok {
t.Fatalf("got %v, want ErrUDPGSODisabled", err)
}
if len(xpc.gotBuffersLen) != 2 {
t.Fatalf("got %d WriteBatch calls, want 2 (initial + retry)", len(xpc.gotBuffersLen))
}
// Both passes take the non-coalesce branch: one msg per buff, no coalescing.
for call, got := range xpc.gotBuffersLen {
if len(got) != len(buffs) {
t.Errorf("call %d sent %d msgs, want %d", call, len(got), len(buffs))
}
}
}
func TestMinReadBatchMsgsLen(t *testing.T) {
// So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
// shaped for wireguard-go to control packet memory, these values should be