feature/conn25: install all the hooks

Install the previously uninstalled hooks for the filter and tstun
intercepts. Move the DNS manager hook installation into Init() with all
the others. Protect all implementations with a short-circuit if the node
is not configured to use Connectors 2025. The short-circuit pattern
replaces the previous pattern used in managing the DNS manager hook, of
setting it to nil in response to CapMap changes.

Fixes tailscale/corp#38716

Signed-off-by: Michael Ben-Ami <mzb@tailscale.com>
This commit is contained in:
Michael Ben-Ami
2026-03-26 16:38:04 -04:00
committed by mzbenami
parent 70fabf1716
commit 156e6ae5cd
4 changed files with 131 additions and 76 deletions

View File

@@ -27,9 +27,9 @@
"tailscale.com/feature"
"tailscale.com/ipn/ipnext"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/net/dns"
"tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tstun"
"tailscale.com/tailcfg"
"tailscale.com/types/appctype"
"tailscale.com/types/key"
@@ -39,6 +39,7 @@
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/testenv"
"tailscale.com/wgengine/filter"
)
// featureName is the name of the feature implemented by this package.
@@ -100,9 +101,6 @@ type extension struct {
host ipnext.Host // set in Init, read-only after
ctxCancel context.CancelCauseFunc // cancels sendLoop goroutine
mu sync.Mutex // protects the fields below
isDNSHookRegistered bool
}
// Name implements [ipnext.Extension].
@@ -117,23 +115,118 @@ func (e *extension) Init(host ipnext.Host) error {
return ipnext.SkipExtension
}
//Init only once
e.mu.Lock()
defer e.mu.Unlock()
if e.ctxCancel != nil {
return nil
}
e.host = host
host.Hooks().OnSelfChange.Add(e.onSelfChange)
host.Hooks().ExtraRouterConfigRoutes.Set(e.getMagicRange)
host.Hooks().ExtraWireGuardAllowedIPs.Set(e.extraWireGuardAllowedIPs)
dph := newDatapathHandler(e.conn25, e.conn25.client.logf)
if err := e.installHooks(dph); err != nil {
return err
}
ctx, cancel := context.WithCancelCause(context.Background())
e.ctxCancel = cancel
go e.sendLoop(ctx)
return nil
}
func (e *extension) installHooks(dph *datapathHandler) error {
// Make sure we can access the DNS manager and the system tun.
dnsManager, ok := e.backend.Sys().DNSManager.GetOK()
if !ok {
return errors.New("could not access system dns manager")
}
tun, ok := e.backend.Sys().Tun.GetOK()
if !ok {
return errors.New("could not access system tun")
}
// Set up the DNS manager to rewrite responses for app domains
// to answer with Magic IPs.
dnsManager.SetQueryResponseMapper(func(bs []byte) []byte {
if !e.conn25.isConfigured() {
return bs
}
return e.conn25.mapDNSResponse(bs)
})
// Intercept packets from the tun device and from WireGuard
// to perform DNAT and SNAT.
tun.PreFilterPacketOutboundToWireGuardAppConnectorIntercept = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response {
if !e.conn25.isConfigured() {
return filter.Accept
}
return dph.HandlePacketFromTunDevice(p)
}
tun.PostFilterPacketInboundFromWireGuardAppConnector = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response {
if !e.conn25.isConfigured() {
return filter.Accept
}
return dph.HandlePacketFromWireGuard(p)
}
// Manage how we react to changes to the current node,
// including property changes (e.g. HostInfo, Capabilities, CapMap)
// and profile switches.
e.host.Hooks().OnSelfChange.Add(e.onSelfChange)
// Allow the client to send packets with Transit IP destinations
// in the link-local space.
e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) {
if !e.conn25.isConfigured() {
return false, ""
}
return e.conn25.client.linkLocalAllow(p)
})
// Allow the connector to receive packets with Transit IP destinations
// in the link-local space.
e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) {
if !e.conn25.isConfigured() {
return false, ""
}
return e.conn25.connector.packetFilterAllow(p)
})
// Allow the connector to receive packets with Transit IP destinations
// that are not "local" to it, and that it does not advertise.
e.host.Hooks().Filter.IngressAllowHooks.Add(func(p packet.Parsed) (bool, string) {
if !e.conn25.isConfigured() {
return false, ""
}
return e.conn25.connector.packetFilterAllow(p)
})
// Give the client the Magic IP range to install on the OS.
e.host.Hooks().ExtraRouterConfigRoutes.Set(func() views.Slice[netip.Prefix] {
if !e.conn25.isConfigured() {
return views.Slice[netip.Prefix]{}
}
return e.getMagicRange()
})
// Tell WireGuard what Transit IPs belong to which connector peers.
e.host.Hooks().ExtraWireGuardAllowedIPs.Set(func(k key.NodePublic) views.Slice[netip.Prefix] {
if !e.conn25.isConfigured() {
return views.Slice[netip.Prefix]{}
}
return e.extraWireGuardAllowedIPs(k)
})
return nil
}
// ClientTransitIPForMagicIP implements [IPMapper].
func (c *Conn25) ClientTransitIPForMagicIP(m netip.Addr) (netip.Addr, error) {
return c.client.transitIPForMagicIP(m)
}
// ConnectorRealIPForTransitIPConnection implements [IPMapper].
func (c *Conn25) ConnectorRealIPForTransitIPConnection(src, transit netip.Addr) (netip.Addr, error) {
return c.connector.realIPForTransitIPConnection(src, transit)
}
func (e *extension) getMagicRange() views.Slice[netip.Prefix] {
cfg := e.conn25.client.getConfig()
return views.SliceOf(cfg.magicIPSet.Prefixes())
@@ -177,56 +270,12 @@ func (e *extension) onSelfChange(selfNode tailcfg.NodeView) {
e.conn25.client.logf("error during Reconfig onSelfChange: %v", err)
return
}
if e.conn25.isConfigured() {
err = e.registerDNSHook()
} else {
err = e.unregisterDNSHook()
}
if err != nil {
e.conn25.client.logf("error managing DNS hook onSelfChange: %v", err)
}
}
func (e *extension) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] {
return e.conn25.client.extraWireGuardAllowedIPs(k)
}
func (e *extension) registerDNSHook() error {
e.mu.Lock()
defer e.mu.Unlock()
if e.isDNSHookRegistered {
return nil
}
err := e.setDNSHookLocked(e.conn25.mapDNSResponse)
if err == nil {
e.isDNSHookRegistered = true
}
return err
}
func (e *extension) unregisterDNSHook() error {
e.mu.Lock()
defer e.mu.Unlock()
if !e.isDNSHookRegistered {
return nil
}
err := e.setDNSHookLocked(nil)
if err == nil {
e.isDNSHookRegistered = false
}
return err
}
func (e *extension) setDNSHookLocked(fx dns.ResponseMapper) error {
dnsManager, ok := e.backend.Sys().DNSManager.GetOK()
if !ok || dnsManager == nil {
return errors.New("couldn't get DNSManager from sys")
}
dnsManager.SetQueryResponseMapper(fx)
return nil
}
type appAddr struct {
app string
addr netip.Addr
@@ -517,8 +566,9 @@ func (c *client) getConfig() config {
return c.config
}
// ClientTransitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups.
func (c *client) ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) {
// transitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups.
// See also [IPMapper.ClientTransitIPForMagicIP].
func (c *client) transitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
v, ok := c.assignments.lookupByMagicIP(magicIP)
@@ -938,8 +988,9 @@ type connector struct {
config config
}
// ConnectorRealIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups.
func (c *connector) ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) {
// realIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups.
// See also [IPMapper.ConnectorRealIPForTransitIPConnection].
func (c *connector) realIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
v, ok := c.lookupBySrcIPAndTransitIP(srcIP, transitIP)

View File

@@ -18,8 +18,10 @@
"go4.org/netipx"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/ipn/ipnext"
"tailscale.com/net/dns"
"tailscale.com/net/packet"
"tailscale.com/net/tsdial"
"tailscale.com/net/tstun"
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/types/appctype"
@@ -818,6 +820,16 @@ type testSafeBackend struct {
sys *tsd.System
}
func newTestSafeBackend() *testSafeBackend {
sb := &testSafeBackend{}
sys := &tsd.System{}
sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard})
sys.DNSManager.Set(&dns.Manager{})
sys.Tun.Set(&tstun.Wrapper{})
sb.sys = sys
return sb
}
func (b *testSafeBackend) Sys() *tsd.System { return b.sys }
// TestAddressAssignmentIsHandled tests that after enqueueAddress has been called
@@ -852,13 +864,9 @@ func TestAddressAssignmentIsHandled(t *testing.T) {
Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 1: 0xff, 31: 0x01})),
}).View()
// make extension to test
sys := &tsd.System{}
sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard})
ext := &extension{
conn25: newConn25(logger.Discard),
backend: &testSafeBackend{sys: sys},
backend: newTestSafeBackend(),
}
authReconfigAsyncCalled := make(chan struct{}, 1)
if err := ext.Init(&testHost{
@@ -1243,13 +1251,9 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) {
}).View(),
}
// make extension to test
sys := &tsd.System{}
sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard})
ext := &extension{
conn25: newConn25(logger.Discard),
backend: &testSafeBackend{sys: sys},
backend: newTestSafeBackend(),
}
authReconfigAsyncCalled := make(chan struct{}, 1)
if err := ext.Init(&testHost{
@@ -1494,7 +1498,7 @@ func TestClientTransitIPForMagicIP(t *testing.T) {
magic: mappedMip,
transit: mappedTip,
})
tip, err := c.client.ClientTransitIPForMagicIP(tt.mip)
tip, err := c.client.transitIPForMagicIP(tt.mip)
if tip != tt.wantTip {
t.Fatalf("checking transit ip: want %v, got %v", tt.wantTip, tip)
}
@@ -1566,7 +1570,7 @@ func TestConnectorRealIPForTransitIPConnection(t *testing.T) {
c.connector.transitIPs = map[netip.Addr]map[netip.Addr]appAddr{}
c.connector.transitIPs[mappedSrc] = map[netip.Addr]appAddr{}
c.connector.transitIPs[mappedSrc][mappedTip] = appAddr{addr: mappedMip}
mip, err := c.connector.ConnectorRealIPForTransitIPConnection(tt.src, tt.tip)
mip, err := c.connector.realIPForTransitIPConnection(tt.src, tt.tip)
if mip != tt.wantMip {
t.Fatalf("checking magic ip: want %v, got %v", tt.wantMip, mip)
}

View File

@@ -37,8 +37,8 @@ type IPMapper interface {
// range for an app on the connector, but not mapped to the client at srcIP, implementations
// should return [ErrUnmappedSrcAndTransitIP]. If the transitIP is not within a configured
// Transit IP range, i.e. it is not actually a Transit IP, implementations should return
// a nil error, a zero-value [netip.Addr] to indicate this is potentially valid, non-app-connector
// traffic.
// a nil error, and a zero-value [netip.Addr] to indicate this is potentially valid,
// non-app-connector traffic.
ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error)
}

View File

@@ -112,7 +112,7 @@ func TestHandlePacketFromTunDevice(t *testing.T) {
}
return netip.Addr{}, nil
}
dph := newDatapathHandler(mock, nil)
dph := newDatapathHandler(mock, t.Logf)
tt.p.IPProto = ipproto.UDP
tt.p.IPVersion = 4
@@ -217,7 +217,7 @@ func TestHandlePacketFromWireGuard(t *testing.T) {
}
return netip.Addr{}, nil
}
dph := newDatapathHandler(mock, nil)
dph := newDatapathHandler(mock, t.Logf)
tt.p.IPProto = ipproto.UDP
tt.p.IPVersion = 4
@@ -254,7 +254,7 @@ func TestClientFlowCache(t *testing.T) {
getTransitIPCalled = true
return transitIP, nil
}
dph := newDatapathHandler(mock, nil)
dph := newDatapathHandler(mock, t.Logf)
outgoing := packet.Parsed{
IPProto: ipproto.UDP,
@@ -316,7 +316,7 @@ func TestConnectorFlowCache(t *testing.T) {
getRealIPCalled = true
return realIP, nil
}
dph := newDatapathHandler(mock, nil)
dph := newDatapathHandler(mock, t.Logf)
outgoing := packet.Parsed{
IPProto: ipproto.UDP,