From 156e6ae5cd37f77ad0726b313105158c818779dc Mon Sep 17 00:00:00 2001 From: Michael Ben-Ami Date: Thu, 26 Mar 2026 16:38:04 -0400 Subject: [PATCH] feature/conn25: install all the hooks Install the previously uninstalled hooks for the filter and tstun intercepts. Move the DNS manager hook installation into Init() with all the others. Protect all implementations with a short-circuit if the node is not configured to use Connectors 2025. The short-circuit pattern replaces the previous pattern used in managing the DNS manager hook, of setting it to nil in response to CapMap changes. Fixes tailscale/corp#38716 Signed-off-by: Michael Ben-Ami --- feature/conn25/conn25.go | 167 +++++++++++++++++++++----------- feature/conn25/conn25_test.go | 28 +++--- feature/conn25/datapath.go | 4 +- feature/conn25/datapath_test.go | 8 +- 4 files changed, 131 insertions(+), 76 deletions(-) diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index 98af443d4..eeb02c5f8 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -27,9 +27,9 @@ "tailscale.com/feature" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" - "tailscale.com/net/dns" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/types/appctype" "tailscale.com/types/key" @@ -39,6 +39,7 @@ "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/testenv" + "tailscale.com/wgengine/filter" ) // featureName is the name of the feature implemented by this package. @@ -100,9 +101,6 @@ type extension struct { host ipnext.Host // set in Init, read-only after ctxCancel context.CancelCauseFunc // cancels sendLoop goroutine - - mu sync.Mutex // protects the fields below - isDNSHookRegistered bool } // Name implements [ipnext.Extension]. @@ -117,23 +115,118 @@ func (e *extension) Init(host ipnext.Host) error { return ipnext.SkipExtension } - //Init only once - e.mu.Lock() - defer e.mu.Unlock() if e.ctxCancel != nil { return nil } e.host = host - host.Hooks().OnSelfChange.Add(e.onSelfChange) - host.Hooks().ExtraRouterConfigRoutes.Set(e.getMagicRange) - host.Hooks().ExtraWireGuardAllowedIPs.Set(e.extraWireGuardAllowedIPs) + dph := newDatapathHandler(e.conn25, e.conn25.client.logf) + if err := e.installHooks(dph); err != nil { + return err + } + ctx, cancel := context.WithCancelCause(context.Background()) e.ctxCancel = cancel go e.sendLoop(ctx) return nil } +func (e *extension) installHooks(dph *datapathHandler) error { + // Make sure we can access the DNS manager and the system tun. + dnsManager, ok := e.backend.Sys().DNSManager.GetOK() + if !ok { + return errors.New("could not access system dns manager") + } + tun, ok := e.backend.Sys().Tun.GetOK() + if !ok { + return errors.New("could not access system tun") + } + + // Set up the DNS manager to rewrite responses for app domains + // to answer with Magic IPs. + dnsManager.SetQueryResponseMapper(func(bs []byte) []byte { + if !e.conn25.isConfigured() { + return bs + } + return e.conn25.mapDNSResponse(bs) + }) + + // Intercept packets from the tun device and from WireGuard + // to perform DNAT and SNAT. + tun.PreFilterPacketOutboundToWireGuardAppConnectorIntercept = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response { + if !e.conn25.isConfigured() { + return filter.Accept + } + return dph.HandlePacketFromTunDevice(p) + } + tun.PostFilterPacketInboundFromWireGuardAppConnector = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response { + if !e.conn25.isConfigured() { + return filter.Accept + } + return dph.HandlePacketFromWireGuard(p) + } + + // Manage how we react to changes to the current node, + // including property changes (e.g. HostInfo, Capabilities, CapMap) + // and profile switches. + e.host.Hooks().OnSelfChange.Add(e.onSelfChange) + + // Allow the client to send packets with Transit IP destinations + // in the link-local space. + e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.client.linkLocalAllow(p) + }) + + // Allow the connector to receive packets with Transit IP destinations + // in the link-local space. + e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.connector.packetFilterAllow(p) + }) + + // Allow the connector to receive packets with Transit IP destinations + // that are not "local" to it, and that it does not advertise. + e.host.Hooks().Filter.IngressAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.connector.packetFilterAllow(p) + }) + + // Give the client the Magic IP range to install on the OS. + e.host.Hooks().ExtraRouterConfigRoutes.Set(func() views.Slice[netip.Prefix] { + if !e.conn25.isConfigured() { + return views.Slice[netip.Prefix]{} + } + return e.getMagicRange() + }) + + // Tell WireGuard what Transit IPs belong to which connector peers. + e.host.Hooks().ExtraWireGuardAllowedIPs.Set(func(k key.NodePublic) views.Slice[netip.Prefix] { + if !e.conn25.isConfigured() { + return views.Slice[netip.Prefix]{} + } + return e.extraWireGuardAllowedIPs(k) + }) + + return nil +} + +// ClientTransitIPForMagicIP implements [IPMapper]. +func (c *Conn25) ClientTransitIPForMagicIP(m netip.Addr) (netip.Addr, error) { + return c.client.transitIPForMagicIP(m) +} + +// ConnectorRealIPForTransitIPConnection implements [IPMapper]. +func (c *Conn25) ConnectorRealIPForTransitIPConnection(src, transit netip.Addr) (netip.Addr, error) { + return c.connector.realIPForTransitIPConnection(src, transit) +} + func (e *extension) getMagicRange() views.Slice[netip.Prefix] { cfg := e.conn25.client.getConfig() return views.SliceOf(cfg.magicIPSet.Prefixes()) @@ -177,56 +270,12 @@ func (e *extension) onSelfChange(selfNode tailcfg.NodeView) { e.conn25.client.logf("error during Reconfig onSelfChange: %v", err) return } - - if e.conn25.isConfigured() { - err = e.registerDNSHook() - } else { - err = e.unregisterDNSHook() - } - if err != nil { - e.conn25.client.logf("error managing DNS hook onSelfChange: %v", err) - } } func (e *extension) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { return e.conn25.client.extraWireGuardAllowedIPs(k) } -func (e *extension) registerDNSHook() error { - e.mu.Lock() - defer e.mu.Unlock() - if e.isDNSHookRegistered { - return nil - } - err := e.setDNSHookLocked(e.conn25.mapDNSResponse) - if err == nil { - e.isDNSHookRegistered = true - } - return err -} - -func (e *extension) unregisterDNSHook() error { - e.mu.Lock() - defer e.mu.Unlock() - if !e.isDNSHookRegistered { - return nil - } - err := e.setDNSHookLocked(nil) - if err == nil { - e.isDNSHookRegistered = false - } - return err -} - -func (e *extension) setDNSHookLocked(fx dns.ResponseMapper) error { - dnsManager, ok := e.backend.Sys().DNSManager.GetOK() - if !ok || dnsManager == nil { - return errors.New("couldn't get DNSManager from sys") - } - dnsManager.SetQueryResponseMapper(fx) - return nil -} - type appAddr struct { app string addr netip.Addr @@ -517,8 +566,9 @@ func (c *client) getConfig() config { return c.config } -// ClientTransitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups. -func (c *client) ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { +// transitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups. +// See also [IPMapper.ClientTransitIPForMagicIP]. +func (c *client) transitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { c.mu.Lock() defer c.mu.Unlock() v, ok := c.assignments.lookupByMagicIP(magicIP) @@ -938,8 +988,9 @@ type connector struct { config config } -// ConnectorRealIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups. -func (c *connector) ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { +// realIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups. +// See also [IPMapper.ConnectorRealIPForTransitIPConnection]. +func (c *connector) realIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { c.mu.Lock() defer c.mu.Unlock() v, ok := c.lookupBySrcIPAndTransitIP(srcIP, transitIP) diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 4e2f8a073..0a90c151a 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -18,8 +18,10 @@ "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/ipn/ipnext" + "tailscale.com/net/dns" "tailscale.com/net/packet" "tailscale.com/net/tsdial" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/appctype" @@ -818,6 +820,16 @@ type testSafeBackend struct { sys *tsd.System } +func newTestSafeBackend() *testSafeBackend { + sb := &testSafeBackend{} + sys := &tsd.System{} + sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) + sys.DNSManager.Set(&dns.Manager{}) + sys.Tun.Set(&tstun.Wrapper{}) + sb.sys = sys + return sb +} + func (b *testSafeBackend) Sys() *tsd.System { return b.sys } // TestAddressAssignmentIsHandled tests that after enqueueAddress has been called @@ -852,13 +864,9 @@ func TestAddressAssignmentIsHandled(t *testing.T) { Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 1: 0xff, 31: 0x01})), }).View() - // make extension to test - sys := &tsd.System{} - sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) - ext := &extension{ conn25: newConn25(logger.Discard), - backend: &testSafeBackend{sys: sys}, + backend: newTestSafeBackend(), } authReconfigAsyncCalled := make(chan struct{}, 1) if err := ext.Init(&testHost{ @@ -1243,13 +1251,9 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { }).View(), } - // make extension to test - sys := &tsd.System{} - sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) - ext := &extension{ conn25: newConn25(logger.Discard), - backend: &testSafeBackend{sys: sys}, + backend: newTestSafeBackend(), } authReconfigAsyncCalled := make(chan struct{}, 1) if err := ext.Init(&testHost{ @@ -1494,7 +1498,7 @@ func TestClientTransitIPForMagicIP(t *testing.T) { magic: mappedMip, transit: mappedTip, }) - tip, err := c.client.ClientTransitIPForMagicIP(tt.mip) + tip, err := c.client.transitIPForMagicIP(tt.mip) if tip != tt.wantTip { t.Fatalf("checking transit ip: want %v, got %v", tt.wantTip, tip) } @@ -1566,7 +1570,7 @@ func TestConnectorRealIPForTransitIPConnection(t *testing.T) { c.connector.transitIPs = map[netip.Addr]map[netip.Addr]appAddr{} c.connector.transitIPs[mappedSrc] = map[netip.Addr]appAddr{} c.connector.transitIPs[mappedSrc][mappedTip] = appAddr{addr: mappedMip} - mip, err := c.connector.ConnectorRealIPForTransitIPConnection(tt.src, tt.tip) + mip, err := c.connector.realIPForTransitIPConnection(tt.src, tt.tip) if mip != tt.wantMip { t.Fatalf("checking magic ip: want %v, got %v", tt.wantMip, mip) } diff --git a/feature/conn25/datapath.go b/feature/conn25/datapath.go index cc45edf63..b5cdd5155 100644 --- a/feature/conn25/datapath.go +++ b/feature/conn25/datapath.go @@ -37,8 +37,8 @@ type IPMapper interface { // range for an app on the connector, but not mapped to the client at srcIP, implementations // should return [ErrUnmappedSrcAndTransitIP]. If the transitIP is not within a configured // Transit IP range, i.e. it is not actually a Transit IP, implementations should return - // a nil error, a zero-value [netip.Addr] to indicate this is potentially valid, non-app-connector - // traffic. + // a nil error, and a zero-value [netip.Addr] to indicate this is potentially valid, + // non-app-connector traffic. ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) } diff --git a/feature/conn25/datapath_test.go b/feature/conn25/datapath_test.go index a4a3363b7..f75b89d29 100644 --- a/feature/conn25/datapath_test.go +++ b/feature/conn25/datapath_test.go @@ -112,7 +112,7 @@ func TestHandlePacketFromTunDevice(t *testing.T) { } return netip.Addr{}, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) tt.p.IPProto = ipproto.UDP tt.p.IPVersion = 4 @@ -217,7 +217,7 @@ func TestHandlePacketFromWireGuard(t *testing.T) { } return netip.Addr{}, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) tt.p.IPProto = ipproto.UDP tt.p.IPVersion = 4 @@ -254,7 +254,7 @@ func TestClientFlowCache(t *testing.T) { getTransitIPCalled = true return transitIP, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) outgoing := packet.Parsed{ IPProto: ipproto.UDP, @@ -316,7 +316,7 @@ func TestConnectorFlowCache(t *testing.T) { getRealIPCalled = true return realIP, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) outgoing := packet.Parsed{ IPProto: ipproto.UDP,