diff --git a/net/batching/conn_default.go b/net/batching/conn_default.go index 37d644f50..9d3bda9bf 100644 --- a/net/batching/conn_default.go +++ b/net/batching/conn_default.go @@ -9,8 +9,8 @@ "tailscale.com/types/nettype" ) -// TryUpgradeToConn is no-op on all platforms except linux. -func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn { +// tryUpgradeToConn is no-op on all platforms except linux. +func tryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn { return pconn } diff --git a/net/batching/conn_linux.go b/net/batching/conn_linux.go index bd7ac25be..69fa0fdee 100644 --- a/net/batching/conn_linux.go +++ b/net/batching/conn_linux.go @@ -46,6 +46,9 @@ type xnetBatchWriter interface { var ( // [linuxBatchingConn] implements [Conn]. _ Conn = (*linuxBatchingConn)(nil) + + // [linuxBatchingConn] implements [syscall.Conn]. + _ syscall.Conn = (*linuxBatchingConn)(nil) ) // linuxBatchingConn is a UDP socket that provides batched i/o. It implements @@ -383,10 +386,10 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) { *control = (*control)[:unix.CmsgSpace(2)] } -// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades +// tryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades // pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is // suggested for the best performance. -func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { +func tryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { if runtime.GOOS != "linux" { // Exclude Android. return pconn diff --git a/net/batching/listener.go b/net/batching/listener.go new file mode 100644 index 000000000..00ced3cd5 --- /dev/null +++ b/net/batching/listener.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package batching + +import ( + "context" + "net" + + "tailscale.com/types/nettype" +) + +var listenPacket = listenPacketStd + +var _ nettype.PacketListenerWithNetIP = (*PacketListener)(nil) + +// PacketListener is a [nettype.PacketListenerWithNetIP] implementation that +// creates packet connections optimized for high throughput on platforms that +// support batched I/O. +type PacketListener struct { + config *net.ListenConfig + batchSize int +} + +// NewPacketListener returns a new [PacketListener] that uses the provided +// [net.ListenConfig] to configure new connections, and attempts to enable +// batched I/O with the provided batchSize if supported on the current platform. +func NewPacketListener(config *net.ListenConfig, batchSize int) nettype.PacketListenerWithNetIP { + return &PacketListener{config, batchSize} +} + +// ListenPacket implements [nettype.PacketListenerWithNetIP]. +// On platforms that support batched I/O, the returned [nettype.PacketConn] +// is a [Conn]. +func (pl *PacketListener) ListenPacket(ctx context.Context, network, address string) (nettype.PacketConn, error) { + return listenPacket(ctx, network, address, pl.config, pl.batchSize) +} + +var _ nettype.PacketConn = (*net.UDPConn)(nil) + +// listenPacketStd creates a [net.UDPConn] and attempts to upgrade it to +// a [Conn] if supported on the current platform (as of 2026-01-22, only Linux). +func listenPacketStd(ctx context.Context, network, address string, config *net.ListenConfig, batchSize int) (nettype.PacketConn, error) { + conn, err := config.ListenPacket(ctx, network, address) + if err != nil { + return nil, err + } + return tryUpgradeToConn(conn.(nettype.PacketConn), network, batchSize), nil +} diff --git a/net/sockopts/sockopts.go b/net/sockopts/sockopts.go index 0c0ee7692..41b8371aa 100644 --- a/net/sockopts/sockopts.go +++ b/net/sockopts/sockopts.go @@ -20,13 +20,20 @@ WriteDirection BufferDirection = "write" ) +type bufferedConn interface { + SetReadBuffer(bytes int) error + SetWriteBuffer(bytes int) error +} + +var _ bufferedConn = (*net.UDPConn)(nil) + func portableSetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) error { if runtime.GOOS == "plan9" { // Not supported. Don't try. Avoid logspam. return nil } var err error - if c, ok := pconn.(*net.UDPConn); ok { + if c, ok := pconn.(bufferedConn); ok { if direction == WriteDirection { err = c.SetWriteBuffer(size) } else { diff --git a/net/sockopts/sockopts_default.go b/net/sockopts/sockopts_default.go index 3cc8679b5..8ca2cb366 100644 --- a/net/sockopts/sockopts_default.go +++ b/net/sockopts/sockopts_default.go @@ -15,7 +15,7 @@ // errForce is only relevant for Linux, and will always be nil otherwise, // but we maintain a consistent cross-platform API. // -// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op. +// If pconn does not support setting buffer sizes, then SetBufferSize is no-op. func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) { return nil, portableSetBufferSize(pconn, direction, size) } diff --git a/net/sockopts/sockopts_linux.go b/net/sockopts/sockopts_linux.go index 5d778d380..f93cf4a66 100644 --- a/net/sockopts/sockopts_linux.go +++ b/net/sockopts/sockopts_linux.go @@ -6,7 +6,6 @@ package sockopts import ( - "net" "syscall" "tailscale.com/types/nettype" @@ -18,13 +17,13 @@ // the portable implementation (errPortable) if that fails, which may be // silently capped to net.core.{r,w}mem_max. // -// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op. +// If pconn does not support setting buffer sizes, then SetBufferSize is no-op. func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) { opt := syscall.SO_RCVBUFFORCE if direction == WriteDirection { opt = syscall.SO_SNDBUFFORCE } - if c, ok := pconn.(*net.UDPConn); ok { + if c, ok := pconn.(syscall.Conn); ok { var rc syscall.RawConn rc, errForce = c.SyscallConn() if errForce == nil { diff --git a/net/sockopts/sockopts_windows.go b/net/sockopts/sockopts_windows.go index 1e6c3f69d..a18c4727a 100644 --- a/net/sockopts/sockopts_windows.go +++ b/net/sockopts/sockopts_windows.go @@ -7,7 +7,7 @@ import ( "fmt" - "net" + "syscall" "unsafe" "golang.org/x/sys/windows" @@ -17,9 +17,9 @@ // SetICMPErrImmunity sets socket options on pconn to prevent ICMP reception, // e.g. ICMP Port Unreachable, from surfacing as a syscall error. // -// If pconn is not a [*net.UDPConn], then SetICMPErrImmunity is no-op. +// If pconn is not a [syscall.Conn], then SetICMPErrImmunity is no-op. func SetICMPErrImmunity(pconn nettype.PacketConn) error { - c, ok := pconn.(*net.UDPConn) + c, ok := pconn.(syscall.Conn) if !ok { // not a UDP connection; nothing to do return nil diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 2b6d38923..661ebe1e5 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -514,11 +514,11 @@ func (s *Server) addrDiscoveryLoop() { // singlePacketConn implements [batching.Conn] with single packet syscall // operations. type singlePacketConn struct { - *net.UDPConn + nettype.PacketConn } func (c *singlePacketConn) ReadBatch(msgs []ipv6.Message, _ int) (int, error) { - n, ap, err := c.UDPConn.ReadFromUDPAddrPort(msgs[0].Buffers[0]) + n, ap, err := c.PacketConn.ReadFromUDPAddrPort(msgs[0].Buffers[0]) if err != nil { return 0, err } @@ -534,7 +534,7 @@ func (c *singlePacketConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, gen } else { buff = buff[offset:] } - _, err := c.UDPConn.WriteToUDPAddrPort(buff, addr) + _, err := c.PacketConn.WriteToUDPAddrPort(buff, addr) if err != nil { return err } @@ -623,7 +623,7 @@ func (s *Server) bindSockets(desiredPort uint16) error { desiredPort = s.uc6Port } } - uc, boundPort, err := s.bindSocketTo(listenConfig, network, desiredPort) + uc, boundPort, err := s.bindSocketTo(listenConfig, network, desiredPort, batching.IdealBatchSize) if err != nil { switch { case i == 0 && network == "udp4": @@ -639,8 +639,7 @@ func (s *Server) bindSockets(desiredPort uint16) error { break SocketsLoop } } - pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize) - bc, ok := pc.(batching.Conn) + bc, ok := uc.(batching.Conn) if !ok { bc = &singlePacketConn{uc} } @@ -663,12 +662,11 @@ func (s *Server) bindSockets(desiredPort uint16) error { return nil } -func (s *Server) bindSocketTo(listenConfig *net.ListenConfig, network string, port uint16) (*net.UDPConn, uint16, error) { - lis, err := listenConfig.ListenPacket(context.Background(), network, fmt.Sprintf(":%d", port)) +func (s *Server) bindSocketTo(listenConfig *net.ListenConfig, network string, port uint16, batchSize int) (nettype.PacketConn, uint16, error) { + uc, err := batching.NewPacketListener(listenConfig, batchSize).ListenPacket(context.Background(), network, fmt.Sprintf(":%d", port)) if err != nil { return nil, 0, err } - uc := lis.(*net.UDPConn) trySetUDPSocketOptions(uc, s.logf) _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) if err != nil { diff --git a/net/udprelay/server_linux.go b/net/udprelay/server_linux.go index d4cf2a2b1..106efb071 100644 --- a/net/udprelay/server_linux.go +++ b/net/udprelay/server_linux.go @@ -6,10 +6,10 @@ package udprelay import ( - "net" "syscall" "golang.org/x/sys/unix" + "tailscale.com/types/nettype" ) func trySetReusePort(_ string, _ string, c syscall.RawConn) { @@ -18,8 +18,12 @@ func trySetReusePort(_ string, _ string, c syscall.RawConn) { }) } -func isReusableSocket(uc *net.UDPConn) bool { - rc, err := uc.SyscallConn() +func isReusableSocket(pc nettype.PacketConn) bool { + sc, ok := pc.(syscall.Conn) + if !ok { + return false + } + rc, err := sc.SyscallConn() if err != nil { return false } diff --git a/net/udprelay/server_notlinux.go b/net/udprelay/server_notlinux.go index f21020631..8a2ae7c61 100644 --- a/net/udprelay/server_notlinux.go +++ b/net/udprelay/server_notlinux.go @@ -6,12 +6,13 @@ package udprelay import ( - "net" "syscall" + + "tailscale.com/types/nettype" ) func trySetReusePort(_ string, _ string, _ syscall.RawConn) {} -func isReusableSocket(*net.UDPConn) bool { +func isReusableSocket(nettype.PacketConn) bool { return false } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 8fbd07013..d6741e938 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -3507,7 +3507,7 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er if c.testOnlyPacketListener != nil { return nettype.MakePacketListenerWithNetIP(c.testOnlyPacketListener).ListenPacket(ctx, network, addr) } - return nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, c.netMon)).ListenPacket(ctx, network, addr) + return batching.NewPacketListener(netns.Listener(c.logf, c.netMon), c.bind.BatchSize()).ListenPacket(ctx, network, addr) } // bindSocket binds a UDP socket to ruc. @@ -3527,13 +3527,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur defer ruc.mu.Unlock() if runtime.GOOS == "js" { - ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) + ruc.setConnLocked(newBlockForeverConn()) return nil } if debugAlwaysDERP() { c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network) - ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) + ruc.setConnLocked(newBlockForeverConn()) return nil } @@ -3592,7 +3592,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur if debugBindSocket() { c.logf("magicsock: bindSocket: successfully listened %v port %d", network, port) } - ruc.setConnLocked(pconn, network, c.bind.BatchSize()) + ruc.setConnLocked(pconn) if network == "udp4" { c.health.SetUDP4Unbound(false) } @@ -3603,7 +3603,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur // Set pconn to a dummy conn whose reads block until closed. // This keeps the receive funcs alive for a future in which // we get a link change and we can try binding again. - ruc.setConnLocked(newBlockForeverConn(), "", c.bind.BatchSize()) + ruc.setConnLocked(newBlockForeverConn()) if network == "udp4" { c.health.SetUDP4Unbound(true) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 68ab4dfa0..42f25001b 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -2108,8 +2108,8 @@ func TestRebindingUDPConn(t *testing.T) { t.Fatal(err) } defer realConn.Close() - c.setConnLocked(realConn.(nettype.PacketConn), "udp4", 1) - c.setConnLocked(newBlockForeverConn(), "", 1) + c.setConnLocked(realConn.(nettype.PacketConn)) + c.setConnLocked(newBlockForeverConn()) } // https://github.com/tailscale/tailscale/issues/6680: don't ignore diff --git a/wgengine/magicsock/rebinding_conn.go b/wgengine/magicsock/rebinding_conn.go index c98e64570..88a1e983c 100644 --- a/wgengine/magicsock/rebinding_conn.go +++ b/wgengine/magicsock/rebinding_conn.go @@ -37,15 +37,10 @@ type RebindingUDPConn struct { } // setConnLocked sets the provided nettype.PacketConn. It should be called only -// after acquiring RebindingUDPConn.mu. It upgrades the provided -// nettype.PacketConn to a batchingConn when appropriate. This upgrade is -// intentionally pushed closest to where read/write ops occur in order to avoid -// disrupting surrounding code that assumes nettype.PacketConn is a -// *net.UDPConn. -func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { - upc := batching.TryUpgradeToConn(p, network, batchSize) - c.pconn = upc - c.pconnAtomic.Store(&upc) +// after acquiring RebindingUDPConn.mu. +func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) { + c.pconn = p + c.pconnAtomic.Store(&p) c.port = uint16(c.localAddrLocked().Port) }