From 8a294e3c345a969bd6e915accea39cc9aeaa2d96 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Fri, 29 May 2026 15:24:15 -0700 Subject: [PATCH] net/batching: reset Buffers len in WriteBatchTo In case we land on this branch during a goto retry. Also, protect Geneve offset from mutation across retries. Fixes #19927 Signed-off-by: Jordan Whited --- net/batching/conn_linux.go | 9 +- net/batching/conn_linux_test.go | 156 ++++++++++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/net/batching/conn_linux.go b/net/batching/conn_linux.go index 1718e98dd..f094f4b65 100644 --- a/net/batching/conn_linux.go +++ b/net/batching/conn_linux.go @@ -296,15 +296,20 @@ func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, ge if c.txOffload.Load() && (!neverGSOEqualTail || len(buffs) >= appendSentinelTailBatchSizeThreshold) { n = c.coalesceMessages(batch.ua, geneve, buffs, batch.msgs, offset, neverGSOEqualTail) } else { + mutableOffset := offset // don't mutate offset across retries vniIsSet := geneve.VNI.IsSet() if vniIsSet { - offset -= packet.GeneveFixedHeaderLength + mutableOffset -= packet.GeneveFixedHeaderLength } for i := range buffs { if vniIsSet { geneve.Encode(buffs[i]) } - batch.msgs[i].Buffers[0] = buffs[i][offset:] + batch.msgs[i].Buffers[0] = buffs[i][mutableOffset:] + // Buffers length may be > 1 (scatter-gather) if we passed through + // coalesceMessages during a first pass, and landed here as part of + // goto retry. + batch.msgs[i].Buffers = batch.msgs[i].Buffers[:1] batch.msgs[i].Addr = batch.ua batch.msgs[i].OOB = batch.msgs[i].OOB[:0] } diff --git a/net/batching/conn_linux_test.go b/net/batching/conn_linux_test.go index 857c3d9d7..633e5d607 100644 --- a/net/batching/conn_linux_test.go +++ b/net/batching/conn_linux_test.go @@ -5,9 +5,14 @@ import ( "encoding/binary" + "errors" "io" "math" "net" + "net/netip" + "os" + "sync" + "sync/atomic" "testing" "unsafe" @@ -15,6 +20,7 @@ "github.com/tailscale/wireguard-go/conn" "golang.org/x/net/ipv6" "golang.org/x/sys/unix" + "tailscale.com/net/neterror" "tailscale.com/net/packet" ) @@ -423,6 +429,156 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { } } +// fakeBatchWriter is an xnetBatchReaderWriter that records the Buffers length +// of each message handed to WriteBatch, and optionally fails the first call +// with an error that triggers neterror.ShouldDisableUDPGSO. +type fakeBatchWriter struct { + gotBuffersLen [][]int // Buffers len of each msg, per WriteBatch call + failFirst bool +} + +func (f *fakeBatchWriter) ReadBatch([]ipv6.Message, int) (int, error) { return 0, nil } + +func (f *fakeBatchWriter) WriteBatch(msgs []ipv6.Message, _ int) (int, error) { + snap := make([]int, len(msgs)) + for i := range msgs { + snap[i] = len(msgs[i].Buffers) + } + f.gotBuffersLen = append(f.gotBuffersLen, snap) + if f.failFirst && len(f.gotBuffersLen) == 1 { + return 0, &os.SyscallError{Syscall: "sendmmsg", Err: unix.EIO} + } + return len(msgs), nil +} + +// Test_linuxBatchingConn_WriteBatchTo_resetsBuffersOnGSORetry verifies that +// when a coalesced (scatter-gather) write fails and triggers the GSO-disable +// goto retry, the non-coalesce retry pass resets each message's Buffers back to +// length 1 rather than leaving stale iovecs appended by coalesceMessages. +func Test_linuxBatchingConn_WriteBatchTo_resetsBuffersOnGSORetry(t *testing.T) { + uc, err := net.ListenUDP("udp4", nil) // only for pc.LocalAddr() in the error path + if err != nil { + t.Fatal(err) + } + defer uc.Close() + + xpc := &fakeBatchWriter{failFirst: true} + c := &linuxBatchingConn{ + pc: uc, + xpc: xpc, + sendBatchPool: sync.Pool{New: func() any { + ua := &net.UDPAddr{IP: make([]byte, 16)} + msgs := make([]ipv6.Message, 8) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + msgs[i].OOB = make([]byte, controlMessageSize) + } + return &sendBatch{ua: ua, msgs: msgs} + }}, + } + c.txOffload.Store(true) // force the coalesce path on the first pass + + // Two equal-length buffs coalesce into a single msg whose Buffers grows + // to len 2 (scatter-gather) on the first pass. + buffs := [][]byte{make([]byte, 32), make([]byte, 32)} + + err = c.WriteBatchTo(buffs, netip.MustParseAddrPort("127.0.0.1:1"), packet.GeneveHeader{}, 0) + + // The retry path always returns ErrUDPGSODisabled wrapping the retry's + // result (nil here). + if _, ok := errors.AsType[neterror.ErrUDPGSODisabled](err); !ok { + t.Fatalf("got %v, want ErrUDPGSODisabled", err) + } + if len(xpc.gotBuffersLen) != 2 { + t.Fatalf("got %d WriteBatch calls, want 2", len(xpc.gotBuffersLen)) + } + // First (coalesced) call: one msg with 2 iovecs — confirms the precondition + // that coalesceMessages grew Buffers past length 1. + if got := xpc.gotBuffersLen[0]; len(got) != 1 || got[0] != 2 { + t.Fatalf("first call buffers = %v, want [2]", got) + } + // Retry (non-coalesce) call: sends one msg per buff... + if got := len(xpc.gotBuffersLen[1]); got != len(buffs) { + t.Fatalf("retry call sent %d msgs, want %d", got, len(buffs)) + } + // ...and the fix must have reset every msg's Buffers back to len 1. + for i, n := range xpc.gotBuffersLen[1] { + if n != 1 { + t.Errorf("retry msg[%d] Buffers len = %d, want 1", i, n) + } + } +} + +// Test_linuxBatchingConn_WriteBatchTo_offsetStableOnNonCoalesceRetry verifies +// that the Geneve header offset adjustment in the non-coalesce path is derived +// fresh from the original offset on each pass, rather than accumulating across a +// goto retry. The non-coalesce branch runs on both passes when neverGSOEqualTail +// is set and the batch is small enough to skip coalescing: the first pass fails +// with an error that disables GSO, and the retry re-enters the same branch. +// Since callers pass offset == GeneveFixedHeaderLength, a stale (accumulating) +// offset would underflow to -GeneveFixedHeaderLength and panic on buffs[i][-8:]. +func Test_linuxBatchingConn_WriteBatchTo_offsetStableOnNonCoalesceRetry(t *testing.T) { + uc, err := net.ListenUDP("udp4", nil) // only for pc.LocalAddr() in the error path + if err != nil { + t.Fatal(err) + } + defer uc.Close() + + xpc := &fakeBatchWriter{failFirst: true} + c := &linuxBatchingConn{ + pc: uc, + xpc: xpc, + sendBatchPool: sync.Pool{New: func() any { + ua := &net.UDPAddr{IP: make([]byte, 16)} + msgs := make([]ipv6.Message, appendSentinelTailBatchSizeThreshold) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Addr = ua + msgs[i].OOB = make([]byte, controlMessageSize) + } + return &sendBatch{ua: ua, msgs: msgs} + }}, + } + c.txOffload.Store(true) + // neverGSOEqualTail set + a sub-threshold batch forces the non-coalesce + // path while txOffload is still enabled, so the GSO-disable retry re-enters + // the non-coalesce branch a second time. + var neverGSOEqualTail atomic.Bool + neverGSOEqualTail.Store(true) + c.neverGSOEqualTail = &neverGSOEqualTail + + // VNI set so the non-coalesce branch performs the offset -= GeneveFixedHeaderLength + // adjustment; offset == GeneveFixedHeaderLength as the production caller requires. + geneve := packet.GeneveHeader{Protocol: packet.GeneveProtocolWireGuard} + geneve.VNI.Set(1) + offset := packet.GeneveFixedHeaderLength + + // Stay below appendSentinelTailBatchSizeThreshold so coalescing is skipped + // and we take the non-coalesce branch on both passes. + const nBuffs = appendSentinelTailBatchSizeThreshold - 1 + buffs := make([][]byte, nBuffs) + for i := range buffs { + buffs[i] = make([]byte, 32) + } + + // Must not panic: each pass recomputes the offset from the original. + err = c.WriteBatchTo(buffs, netip.MustParseAddrPort("127.0.0.1:1"), geneve, offset) + + if _, ok := errors.AsType[neterror.ErrUDPGSODisabled](err); !ok { + t.Fatalf("got %v, want ErrUDPGSODisabled", err) + } + if len(xpc.gotBuffersLen) != 2 { + t.Fatalf("got %d WriteBatch calls, want 2 (initial + retry)", len(xpc.gotBuffersLen)) + } + // Both passes take the non-coalesce branch: one msg per buff, no coalescing. + for call, got := range xpc.gotBuffersLen { + if len(got) != len(buffs) { + t.Errorf("call %d sent %d msgs, want %d", call, len(got), len(buffs)) + } + } +} + func TestMinReadBatchMsgsLen(t *testing.T) { // So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is // shaped for wireguard-go to control packet memory, these values should be