net/{batching,sockopts,udprelay},wgengine/magicsock: create batching.Conn early

We used to create a net.UDPConn by calling (*net.ListenConfig).ListenPacket, then conditionally
upgrade it to a batching.Conn on platforms that support it. This works on Linux, where we can
upgrade an existing net.UDPConn to support batching I/O, but it does not work as well on other
platforms where batching may need to be implemented in terms of a platform-specific APIs.

Also, since in practice all nettype.PacketConn implementations were net.UDPConn (at least
temporarily before the upgrade), we used type assertions to *net.UDPConn even when any
syscall.Conn (a type with a SyscallConn method) would work.

In this PR, as preparation for implementing batching.Conn for Windows, we replace those type
assertions with interfaces and add batching.PacketListener. The default implementation creates
a net.UDPConn and tries to upgrade it to a batching.Conn, while allowing platforms to override
ListenPacket as needed.

We then unexport batching.TryUpgradeToConn and replace its usage with batching.PacketListener.

Updates tailscale/corp#36208
This commit is contained in:
Nick Khyl
2026-01-22 11:11:16 -06:00
parent 532662e701
commit 47d3255bd3
13 changed files with 98 additions and 42 deletions

View File

@@ -9,8 +9,8 @@
"tailscale.com/types/nettype"
)
// TryUpgradeToConn is no-op on all platforms except linux.
func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
// tryUpgradeToConn is no-op on all platforms except linux.
func tryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn {
return pconn
}

View File

@@ -46,6 +46,9 @@ type xnetBatchWriter interface {
var (
// [linuxBatchingConn] implements [Conn].
_ Conn = (*linuxBatchingConn)(nil)
// [linuxBatchingConn] implements [syscall.Conn].
_ syscall.Conn = (*linuxBatchingConn)(nil)
)
// linuxBatchingConn is a UDP socket that provides batched i/o. It implements
@@ -383,10 +386,10 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
*control = (*control)[:unix.CmsgSpace(2)]
}
// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
// tryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
// pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is
// suggested for the best performance.
func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
func tryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
if runtime.GOOS != "linux" {
// Exclude Android.
return pconn

49
net/batching/listener.go Normal file
View File

@@ -0,0 +1,49 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package batching
import (
"context"
"net"
"tailscale.com/types/nettype"
)
var listenPacket = listenPacketStd
var _ nettype.PacketListenerWithNetIP = (*PacketListener)(nil)
// PacketListener is a [nettype.PacketListenerWithNetIP] implementation that
// creates packet connections optimized for high throughput on platforms that
// support batched I/O.
type PacketListener struct {
config *net.ListenConfig
batchSize int
}
// NewPacketListener returns a new [PacketListener] that uses the provided
// [net.ListenConfig] to configure new connections, and attempts to enable
// batched I/O with the provided batchSize if supported on the current platform.
func NewPacketListener(config *net.ListenConfig, batchSize int) nettype.PacketListenerWithNetIP {
return &PacketListener{config, batchSize}
}
// ListenPacket implements [nettype.PacketListenerWithNetIP].
// On platforms that support batched I/O, the returned [nettype.PacketConn]
// is a [Conn].
func (pl *PacketListener) ListenPacket(ctx context.Context, network, address string) (nettype.PacketConn, error) {
return listenPacket(ctx, network, address, pl.config, pl.batchSize)
}
var _ nettype.PacketConn = (*net.UDPConn)(nil)
// listenPacketStd creates a [net.UDPConn] and attempts to upgrade it to
// a [Conn] if supported on the current platform (as of 2026-01-22, only Linux).
func listenPacketStd(ctx context.Context, network, address string, config *net.ListenConfig, batchSize int) (nettype.PacketConn, error) {
conn, err := config.ListenPacket(ctx, network, address)
if err != nil {
return nil, err
}
return tryUpgradeToConn(conn.(nettype.PacketConn), network, batchSize), nil
}

View File

@@ -20,13 +20,20 @@
WriteDirection BufferDirection = "write"
)
type bufferedConn interface {
SetReadBuffer(bytes int) error
SetWriteBuffer(bytes int) error
}
var _ bufferedConn = (*net.UDPConn)(nil)
func portableSetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) error {
if runtime.GOOS == "plan9" {
// Not supported. Don't try. Avoid logspam.
return nil
}
var err error
if c, ok := pconn.(*net.UDPConn); ok {
if c, ok := pconn.(bufferedConn); ok {
if direction == WriteDirection {
err = c.SetWriteBuffer(size)
} else {

View File

@@ -15,7 +15,7 @@
// errForce is only relevant for Linux, and will always be nil otherwise,
// but we maintain a consistent cross-platform API.
//
// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op.
// If pconn does not support setting buffer sizes, then SetBufferSize is no-op.
func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) {
return nil, portableSetBufferSize(pconn, direction, size)
}

View File

@@ -6,7 +6,6 @@
package sockopts
import (
"net"
"syscall"
"tailscale.com/types/nettype"
@@ -18,13 +17,13 @@
// the portable implementation (errPortable) if that fails, which may be
// silently capped to net.core.{r,w}mem_max.
//
// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op.
// If pconn does not support setting buffer sizes, then SetBufferSize is no-op.
func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) {
opt := syscall.SO_RCVBUFFORCE
if direction == WriteDirection {
opt = syscall.SO_SNDBUFFORCE
}
if c, ok := pconn.(*net.UDPConn); ok {
if c, ok := pconn.(syscall.Conn); ok {
var rc syscall.RawConn
rc, errForce = c.SyscallConn()
if errForce == nil {

View File

@@ -7,7 +7,7 @@
import (
"fmt"
"net"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
@@ -17,9 +17,9 @@
// SetICMPErrImmunity sets socket options on pconn to prevent ICMP reception,
// e.g. ICMP Port Unreachable, from surfacing as a syscall error.
//
// If pconn is not a [*net.UDPConn], then SetICMPErrImmunity is no-op.
// If pconn is not a [syscall.Conn], then SetICMPErrImmunity is no-op.
func SetICMPErrImmunity(pconn nettype.PacketConn) error {
c, ok := pconn.(*net.UDPConn)
c, ok := pconn.(syscall.Conn)
if !ok {
// not a UDP connection; nothing to do
return nil

View File

@@ -514,11 +514,11 @@ func (s *Server) addrDiscoveryLoop() {
// singlePacketConn implements [batching.Conn] with single packet syscall
// operations.
type singlePacketConn struct {
*net.UDPConn
nettype.PacketConn
}
func (c *singlePacketConn) ReadBatch(msgs []ipv6.Message, _ int) (int, error) {
n, ap, err := c.UDPConn.ReadFromUDPAddrPort(msgs[0].Buffers[0])
n, ap, err := c.PacketConn.ReadFromUDPAddrPort(msgs[0].Buffers[0])
if err != nil {
return 0, err
}
@@ -534,7 +534,7 @@ func (c *singlePacketConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, gen
} else {
buff = buff[offset:]
}
_, err := c.UDPConn.WriteToUDPAddrPort(buff, addr)
_, err := c.PacketConn.WriteToUDPAddrPort(buff, addr)
if err != nil {
return err
}
@@ -623,7 +623,7 @@ func (s *Server) bindSockets(desiredPort uint16) error {
desiredPort = s.uc6Port
}
}
uc, boundPort, err := s.bindSocketTo(listenConfig, network, desiredPort)
uc, boundPort, err := s.bindSocketTo(listenConfig, network, desiredPort, batching.IdealBatchSize)
if err != nil {
switch {
case i == 0 && network == "udp4":
@@ -639,8 +639,7 @@ func (s *Server) bindSockets(desiredPort uint16) error {
break SocketsLoop
}
}
pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize)
bc, ok := pc.(batching.Conn)
bc, ok := uc.(batching.Conn)
if !ok {
bc = &singlePacketConn{uc}
}
@@ -663,12 +662,11 @@ func (s *Server) bindSockets(desiredPort uint16) error {
return nil
}
func (s *Server) bindSocketTo(listenConfig *net.ListenConfig, network string, port uint16) (*net.UDPConn, uint16, error) {
lis, err := listenConfig.ListenPacket(context.Background(), network, fmt.Sprintf(":%d", port))
func (s *Server) bindSocketTo(listenConfig *net.ListenConfig, network string, port uint16, batchSize int) (nettype.PacketConn, uint16, error) {
uc, err := batching.NewPacketListener(listenConfig, batchSize).ListenPacket(context.Background(), network, fmt.Sprintf(":%d", port))
if err != nil {
return nil, 0, err
}
uc := lis.(*net.UDPConn)
trySetUDPSocketOptions(uc, s.logf)
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
if err != nil {

View File

@@ -6,10 +6,10 @@
package udprelay
import (
"net"
"syscall"
"golang.org/x/sys/unix"
"tailscale.com/types/nettype"
)
func trySetReusePort(_ string, _ string, c syscall.RawConn) {
@@ -18,8 +18,12 @@ func trySetReusePort(_ string, _ string, c syscall.RawConn) {
})
}
func isReusableSocket(uc *net.UDPConn) bool {
rc, err := uc.SyscallConn()
func isReusableSocket(pc nettype.PacketConn) bool {
sc, ok := pc.(syscall.Conn)
if !ok {
return false
}
rc, err := sc.SyscallConn()
if err != nil {
return false
}

View File

@@ -6,12 +6,13 @@
package udprelay
import (
"net"
"syscall"
"tailscale.com/types/nettype"
)
func trySetReusePort(_ string, _ string, _ syscall.RawConn) {}
func isReusableSocket(*net.UDPConn) bool {
func isReusableSocket(nettype.PacketConn) bool {
return false
}

View File

@@ -3507,7 +3507,7 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er
if c.testOnlyPacketListener != nil {
return nettype.MakePacketListenerWithNetIP(c.testOnlyPacketListener).ListenPacket(ctx, network, addr)
}
return nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, c.netMon)).ListenPacket(ctx, network, addr)
return batching.NewPacketListener(netns.Listener(c.logf, c.netMon), c.bind.BatchSize()).ListenPacket(ctx, network, addr)
}
// bindSocket binds a UDP socket to ruc.
@@ -3527,13 +3527,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
defer ruc.mu.Unlock()
if runtime.GOOS == "js" {
ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize())
ruc.setConnLocked(newBlockForeverConn())
return nil
}
if debugAlwaysDERP() {
c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network)
ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize())
ruc.setConnLocked(newBlockForeverConn())
return nil
}
@@ -3592,7 +3592,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
if debugBindSocket() {
c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port)
}
ruc.setConnLocked(pconn, network, c.bind.BatchSize())
ruc.setConnLocked(pconn)
if network == "udp4" {
c.health.SetUDP4Unbound(false)
}
@@ -3603,7 +3603,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur
// Set pconn to a dummy conn whose reads block until closed.
// This keeps the receive funcs alive for a future in which
// we get a link change and we can try binding again.
ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize())
ruc.setConnLocked(newBlockForeverConn())
if network == "udp4" {
c.health.SetUDP4Unbound(true)
}

View File

@@ -2108,8 +2108,8 @@ func TestRebindingUDPConn(t *testing.T) {
t.Fatal(err)
}
defer realConn.Close()
c.setConnLocked(realConn.(nettype.PacketConn), "udp4", 1)
c.setConnLocked(newBlockForeverConn(), "", 1)
c.setConnLocked(realConn.(nettype.PacketConn))
c.setConnLocked(newBlockForeverConn())
}
// https://github.com/tailscale/tailscale/issues/6680: don't ignore

View File

@@ -37,15 +37,10 @@ type RebindingUDPConn struct {
}
// setConnLocked sets the provided nettype.PacketConn. It should be called only
// after acquiring RebindingUDPConn.mu. It upgrades the provided
// nettype.PacketConn to a batchingConn when appropriate. This upgrade is
// intentionally pushed closest to where read/write ops occur in order to avoid
// disrupting surrounding code that assumes nettype.PacketConn is a
// *net.UDPConn.
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) {
upc := batching.TryUpgradeToConn(p, network, batchSize)
c.pconn = upc
c.pconnAtomic.Store(&upc)
// after acquiring RebindingUDPConn.mu.
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) {
c.pconn = p
c.pconnAtomic.Store(&p)
c.port = uint16(c.localAddrLocked().Port)
}