diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index 98af443d4..eeb02c5f8 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -27,9 +27,9 @@ "tailscale.com/feature" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" - "tailscale.com/net/dns" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/types/appctype" "tailscale.com/types/key" @@ -39,6 +39,7 @@ "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/testenv" + "tailscale.com/wgengine/filter" ) // featureName is the name of the feature implemented by this package. @@ -100,9 +101,6 @@ type extension struct { host ipnext.Host // set in Init, read-only after ctxCancel context.CancelCauseFunc // cancels sendLoop goroutine - - mu sync.Mutex // protects the fields below - isDNSHookRegistered bool } // Name implements [ipnext.Extension]. @@ -117,23 +115,118 @@ func (e *extension) Init(host ipnext.Host) error { return ipnext.SkipExtension } - //Init only once - e.mu.Lock() - defer e.mu.Unlock() if e.ctxCancel != nil { return nil } e.host = host - host.Hooks().OnSelfChange.Add(e.onSelfChange) - host.Hooks().ExtraRouterConfigRoutes.Set(e.getMagicRange) - host.Hooks().ExtraWireGuardAllowedIPs.Set(e.extraWireGuardAllowedIPs) + dph := newDatapathHandler(e.conn25, e.conn25.client.logf) + if err := e.installHooks(dph); err != nil { + return err + } + ctx, cancel := context.WithCancelCause(context.Background()) e.ctxCancel = cancel go e.sendLoop(ctx) return nil } +func (e *extension) installHooks(dph *datapathHandler) error { + // Make sure we can access the DNS manager and the system tun. + dnsManager, ok := e.backend.Sys().DNSManager.GetOK() + if !ok { + return errors.New("could not access system dns manager") + } + tun, ok := e.backend.Sys().Tun.GetOK() + if !ok { + return errors.New("could not access system tun") + } + + // Set up the DNS manager to rewrite responses for app domains + // to answer with Magic IPs. + dnsManager.SetQueryResponseMapper(func(bs []byte) []byte { + if !e.conn25.isConfigured() { + return bs + } + return e.conn25.mapDNSResponse(bs) + }) + + // Intercept packets from the tun device and from WireGuard + // to perform DNAT and SNAT. + tun.PreFilterPacketOutboundToWireGuardAppConnectorIntercept = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response { + if !e.conn25.isConfigured() { + return filter.Accept + } + return dph.HandlePacketFromTunDevice(p) + } + tun.PostFilterPacketInboundFromWireGuardAppConnector = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response { + if !e.conn25.isConfigured() { + return filter.Accept + } + return dph.HandlePacketFromWireGuard(p) + } + + // Manage how we react to changes to the current node, + // including property changes (e.g. HostInfo, Capabilities, CapMap) + // and profile switches. + e.host.Hooks().OnSelfChange.Add(e.onSelfChange) + + // Allow the client to send packets with Transit IP destinations + // in the link-local space. + e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.client.linkLocalAllow(p) + }) + + // Allow the connector to receive packets with Transit IP destinations + // in the link-local space. + e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.connector.packetFilterAllow(p) + }) + + // Allow the connector to receive packets with Transit IP destinations + // that are not "local" to it, and that it does not advertise. + e.host.Hooks().Filter.IngressAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.connector.packetFilterAllow(p) + }) + + // Give the client the Magic IP range to install on the OS. + e.host.Hooks().ExtraRouterConfigRoutes.Set(func() views.Slice[netip.Prefix] { + if !e.conn25.isConfigured() { + return views.Slice[netip.Prefix]{} + } + return e.getMagicRange() + }) + + // Tell WireGuard what Transit IPs belong to which connector peers. + e.host.Hooks().ExtraWireGuardAllowedIPs.Set(func(k key.NodePublic) views.Slice[netip.Prefix] { + if !e.conn25.isConfigured() { + return views.Slice[netip.Prefix]{} + } + return e.extraWireGuardAllowedIPs(k) + }) + + return nil +} + +// ClientTransitIPForMagicIP implements [IPMapper]. +func (c *Conn25) ClientTransitIPForMagicIP(m netip.Addr) (netip.Addr, error) { + return c.client.transitIPForMagicIP(m) +} + +// ConnectorRealIPForTransitIPConnection implements [IPMapper]. +func (c *Conn25) ConnectorRealIPForTransitIPConnection(src, transit netip.Addr) (netip.Addr, error) { + return c.connector.realIPForTransitIPConnection(src, transit) +} + func (e *extension) getMagicRange() views.Slice[netip.Prefix] { cfg := e.conn25.client.getConfig() return views.SliceOf(cfg.magicIPSet.Prefixes()) @@ -177,56 +270,12 @@ func (e *extension) onSelfChange(selfNode tailcfg.NodeView) { e.conn25.client.logf("error during Reconfig onSelfChange: %v", err) return } - - if e.conn25.isConfigured() { - err = e.registerDNSHook() - } else { - err = e.unregisterDNSHook() - } - if err != nil { - e.conn25.client.logf("error managing DNS hook onSelfChange: %v", err) - } } func (e *extension) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { return e.conn25.client.extraWireGuardAllowedIPs(k) } -func (e *extension) registerDNSHook() error { - e.mu.Lock() - defer e.mu.Unlock() - if e.isDNSHookRegistered { - return nil - } - err := e.setDNSHookLocked(e.conn25.mapDNSResponse) - if err == nil { - e.isDNSHookRegistered = true - } - return err -} - -func (e *extension) unregisterDNSHook() error { - e.mu.Lock() - defer e.mu.Unlock() - if !e.isDNSHookRegistered { - return nil - } - err := e.setDNSHookLocked(nil) - if err == nil { - e.isDNSHookRegistered = false - } - return err -} - -func (e *extension) setDNSHookLocked(fx dns.ResponseMapper) error { - dnsManager, ok := e.backend.Sys().DNSManager.GetOK() - if !ok || dnsManager == nil { - return errors.New("couldn't get DNSManager from sys") - } - dnsManager.SetQueryResponseMapper(fx) - return nil -} - type appAddr struct { app string addr netip.Addr @@ -517,8 +566,9 @@ func (c *client) getConfig() config { return c.config } -// ClientTransitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups. -func (c *client) ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { +// transitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups. +// See also [IPMapper.ClientTransitIPForMagicIP]. +func (c *client) transitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { c.mu.Lock() defer c.mu.Unlock() v, ok := c.assignments.lookupByMagicIP(magicIP) @@ -938,8 +988,9 @@ type connector struct { config config } -// ConnectorRealIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups. -func (c *connector) ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { +// realIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups. +// See also [IPMapper.ConnectorRealIPForTransitIPConnection]. +func (c *connector) realIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { c.mu.Lock() defer c.mu.Unlock() v, ok := c.lookupBySrcIPAndTransitIP(srcIP, transitIP) diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 4e2f8a073..0a90c151a 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -18,8 +18,10 @@ "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/ipn/ipnext" + "tailscale.com/net/dns" "tailscale.com/net/packet" "tailscale.com/net/tsdial" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/appctype" @@ -818,6 +820,16 @@ type testSafeBackend struct { sys *tsd.System } +func newTestSafeBackend() *testSafeBackend { + sb := &testSafeBackend{} + sys := &tsd.System{} + sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) + sys.DNSManager.Set(&dns.Manager{}) + sys.Tun.Set(&tstun.Wrapper{}) + sb.sys = sys + return sb +} + func (b *testSafeBackend) Sys() *tsd.System { return b.sys } // TestAddressAssignmentIsHandled tests that after enqueueAddress has been called @@ -852,13 +864,9 @@ func TestAddressAssignmentIsHandled(t *testing.T) { Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 1: 0xff, 31: 0x01})), }).View() - // make extension to test - sys := &tsd.System{} - sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) - ext := &extension{ conn25: newConn25(logger.Discard), - backend: &testSafeBackend{sys: sys}, + backend: newTestSafeBackend(), } authReconfigAsyncCalled := make(chan struct{}, 1) if err := ext.Init(&testHost{ @@ -1243,13 +1251,9 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { }).View(), } - // make extension to test - sys := &tsd.System{} - sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) - ext := &extension{ conn25: newConn25(logger.Discard), - backend: &testSafeBackend{sys: sys}, + backend: newTestSafeBackend(), } authReconfigAsyncCalled := make(chan struct{}, 1) if err := ext.Init(&testHost{ @@ -1494,7 +1498,7 @@ func TestClientTransitIPForMagicIP(t *testing.T) { magic: mappedMip, transit: mappedTip, }) - tip, err := c.client.ClientTransitIPForMagicIP(tt.mip) + tip, err := c.client.transitIPForMagicIP(tt.mip) if tip != tt.wantTip { t.Fatalf("checking transit ip: want %v, got %v", tt.wantTip, tip) } @@ -1566,7 +1570,7 @@ func TestConnectorRealIPForTransitIPConnection(t *testing.T) { c.connector.transitIPs = map[netip.Addr]map[netip.Addr]appAddr{} c.connector.transitIPs[mappedSrc] = map[netip.Addr]appAddr{} c.connector.transitIPs[mappedSrc][mappedTip] = appAddr{addr: mappedMip} - mip, err := c.connector.ConnectorRealIPForTransitIPConnection(tt.src, tt.tip) + mip, err := c.connector.realIPForTransitIPConnection(tt.src, tt.tip) if mip != tt.wantMip { t.Fatalf("checking magic ip: want %v, got %v", tt.wantMip, mip) } diff --git a/feature/conn25/datapath.go b/feature/conn25/datapath.go index cc45edf63..b5cdd5155 100644 --- a/feature/conn25/datapath.go +++ b/feature/conn25/datapath.go @@ -37,8 +37,8 @@ type IPMapper interface { // range for an app on the connector, but not mapped to the client at srcIP, implementations // should return [ErrUnmappedSrcAndTransitIP]. If the transitIP is not within a configured // Transit IP range, i.e. it is not actually a Transit IP, implementations should return - // a nil error, a zero-value [netip.Addr] to indicate this is potentially valid, non-app-connector - // traffic. + // a nil error, and a zero-value [netip.Addr] to indicate this is potentially valid, + // non-app-connector traffic. ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) } diff --git a/feature/conn25/datapath_test.go b/feature/conn25/datapath_test.go index a4a3363b7..f75b89d29 100644 --- a/feature/conn25/datapath_test.go +++ b/feature/conn25/datapath_test.go @@ -112,7 +112,7 @@ func TestHandlePacketFromTunDevice(t *testing.T) { } return netip.Addr{}, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) tt.p.IPProto = ipproto.UDP tt.p.IPVersion = 4 @@ -217,7 +217,7 @@ func TestHandlePacketFromWireGuard(t *testing.T) { } return netip.Addr{}, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) tt.p.IPProto = ipproto.UDP tt.p.IPVersion = 4 @@ -254,7 +254,7 @@ func TestClientFlowCache(t *testing.T) { getTransitIPCalled = true return transitIP, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) outgoing := packet.Parsed{ IPProto: ipproto.UDP, @@ -316,7 +316,7 @@ func TestConnectorFlowCache(t *testing.T) { getRealIPCalled = true return realIP, nil } - dph := newDatapathHandler(mock, nil) + dph := newDatapathHandler(mock, t.Logf) outgoing := packet.Parsed{ IPProto: ipproto.UDP,