diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index fb83e86ff..c9ce1ad81 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -32,6 +32,7 @@ "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/appctype" + "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/views" "tailscale.com/util/dnsname" @@ -123,6 +124,7 @@ func (e *extension) Init(host ipnext.Host) error { return nil } e.host = host + host.Hooks().OnSelfChange.Add(e.onSelfChange) host.Hooks().ExtraRouterConfigRoutes.Set(e.getMagicRange) ctx, cancel := context.WithCancelCause(context.Background()) @@ -269,8 +271,7 @@ func (c *Conn25) reconfig(selfNode tailcfg.NodeView) error { } // mapDNSResponse parses and inspects the DNS response, and uses the -// contents to assign addresses for connecting. It does not yet modify -// the response. +// contents to assign addresses for connecting. func (c *Conn25) mapDNSResponse(buf []byte) []byte { return c.client.mapDNSResponse(buf) } @@ -610,6 +611,16 @@ func (c *client) reserveAddresses(domain dnsname.FQDN, dst netip.Addr) (addrs, e return as, nil } +func (c *client) addTransitIPForConnector(tip netip.Addr, conn tailcfg.NodeView) error { + if conn.Key().IsZero() { + return fmt.Errorf("node with stable ID %q does not have a key", conn.StableID()) + } + + c.mu.Lock() + defer c.mu.Unlock() + return c.assignments.insertTransitConnMapping(tip, conn.Key()) +} + func (e *extension) sendLoop(ctx context.Context) { for { select { @@ -624,10 +635,15 @@ func (e *extension) sendLoop(ctx context.Context) { } func (e *extension) handleAddressAssignment(ctx context.Context, as addrs) error { - if err := e.sendAddressAssignment(ctx, as); err != nil { + conn, err := e.sendAddressAssignment(ctx, as) + if err != nil { return err } - // TODO(fran) assign the connector publickey -> transit ip addresses + err = e.conn25.client.addTransitIPForConnector(as.transit, conn) + if err != nil { + return err + } + e.host.AuthReconfigAsync() return nil } @@ -685,27 +701,29 @@ func makePeerAPIReq(ctx context.Context, httpClient *http.Client, urlBase string return nil } -func (e *extension) sendAddressAssignment(ctx context.Context, as addrs) error { +func (e *extension) sendAddressAssignment(ctx context.Context, as addrs) (tailcfg.NodeView, error) { app, ok := e.conn25.client.getConfig().appsByName[as.app] if !ok { e.conn25.client.logf("App not found for app: %s (domain: %s)", as.app, as.domain) - return errors.New("app not found") + return tailcfg.NodeView{}, errors.New("app not found") } nb := e.host.NodeBackend() peers := appc.PickConnector(nb, app) var urlBase string + var conn tailcfg.NodeView for _, p := range peers { urlBase = nb.PeerAPIBase(p) if urlBase != "" { + conn = p break } } if urlBase == "" { - return errors.New("no connector peer found to handle address assignment") + return tailcfg.NodeView{}, errors.New("no connector peer found to handle address assignment") } client := e.backend.Sys().Dialer.Get().PeerAPIHTTPClient() - return makePeerAPIReq(ctx, client, urlBase, as) + return conn, makePeerAPIReq(ctx, client, urlBase, as) } type dnsResponseRewrite struct { @@ -866,6 +884,7 @@ func (c *client) rewriteDNSResponse(hdr dnsmessage.Header, questions []dnsmessag if err := b.StartAnswers(); err != nil { return nil, err } + // make an answer for each rewrite for _, rw := range answers { as, err := c.reserveAddresses(rw.domain, rw.dst) @@ -968,11 +987,15 @@ type domainDst struct { } // addrAssignments is the collection of addrs assigned by this client -// supporting lookup by magic IP, transit IP or domain+dst +// supporting lookup by magic IP, transit IP or domain+dst, or to lookup all +// transit IPs associated with a given connector (identified by its node key). +// byConnKey stores netip.Prefix versions of the transit IPs for use in the +// WireGuard hooks. type addrAssignments struct { byMagicIP map[netip.Addr]addrs byTransitIP map[netip.Addr]addrs byDomainDst map[domainDst]addrs + byConnKey map[key.NodePublic]set.Set[netip.Prefix] } func (a *addrAssignments) insert(as addrs) error { @@ -988,12 +1011,35 @@ func (a *addrAssignments) insert(as addrs) error { if _, ok := a.byTransitIP[as.transit]; ok { return errors.New("byTransitIP key exists") } + mak.Set(&a.byMagicIP, as.magic, as) mak.Set(&a.byTransitIP, as.transit, as) mak.Set(&a.byDomainDst, ddst, as) return nil } +// insertTransitConnMapping adds an entry to the byConnKey map +// for the provided transitIP (as a prefix). +// The provided transitIP must already be present in the byTransitIP map. +func (a *addrAssignments) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error { + if _, ok := a.lookupByTransitIP(tip); !ok { + return errors.New("transit IP is not already known") + } + + ctips, ok := a.byConnKey[connKey] + tipp := netip.PrefixFrom(tip, tip.BitLen()) + if ok { + if ctips.Contains(tipp) { + return errors.New("byConnKey already contains transit") + } + } else { + ctips.Make() + mak.Set(&a.byConnKey, connKey, ctips) + } + ctips.Add(tipp) + return nil +} + func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) { v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] return v, ok @@ -1008,3 +1054,14 @@ func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) { v, ok := a.byTransitIP[tip] return v, ok } + +// lookupTransitIPsByConnKey returns a slice containing the transit IPs (as netipPrefix) +// associated with the given connector (identified by node key), or (nil, false) if there is no entry +// for the given key. +func (a *addrAssignments) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) { + s, ok := a.byConnKey[k] + if !ok { + return nil, false + } + return s.Slice(), true +} diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index ff3ec4c9e..4e2f8a073 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -8,11 +8,13 @@ "net/http" "net/http/httptest" "net/netip" + "slices" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "go4.org/mem" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/ipn/ipnext" @@ -21,6 +23,7 @@ "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/appctype" + "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/opt" "tailscale.com/util/dnsname" @@ -525,13 +528,16 @@ func TestConfigReconfig(t *testing.T) { } } -func makeSelfNode(t *testing.T, attr appctype.Conn25Attr, tags []string) tailcfg.NodeView { +func makeSelfNode(t *testing.T, attrs []appctype.Conn25Attr, tags []string) tailcfg.NodeView { t.Helper() - bs, err := json.Marshal(attr) - if err != nil { - t.Fatalf("unexpected error in test setup: %v", err) + cfg := make([]tailcfg.RawMessage, 0, len(attrs)) + for i, attr := range attrs { + bs, err := json.Marshal(attr) + if err != nil { + t.Fatalf("unexpected error in test setup at index %d: %v", i, err) + } + cfg = append(cfg, tailcfg.RawMessage(bs)) } - cfg := []tailcfg.RawMessage{tailcfg.RawMessage(bs)} capMap := tailcfg.NodeCapMap{ tailcfg.NodeCapability(AppConnectorsExperimentalAttrName): cfg, } @@ -727,13 +733,13 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { dnsResp := makeDNSResponse(t, tt.domain, tt.addrs) - sn := makeSelfNode(t, appctype.Conn25Attr{ + sn := makeSelfNode(t, []appctype.Conn25Attr{{ Name: "app1", Connectors: []string{"tag:woo"}, Domains: []string{"example.com"}, MagicIPPool: []netipx.IPRange{rangeFrom("0", "10"), rangeFrom("20", "30")}, TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")}, - }, []string{}) + }}, []string{}) c := newConn25(logger.Discard) c.reconfig(sn) @@ -843,6 +849,7 @@ func TestAddressAssignmentIsHandled(t *testing.T) { ID: tailcfg.NodeID(1), Tags: []string{"tag:woo"}, Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 1: 0xff, 31: 0x01})), }).View() // make extension to test @@ -867,11 +874,11 @@ func TestAddressAssignmentIsHandled(t *testing.T) { } defer ext.Shutdown() - sn := makeSelfNode(t, appctype.Conn25Attr{ + sn := makeSelfNode(t, []appctype.Conn25Attr{{ Name: "app1", Connectors: []string{"tag:woo"}, Domains: []string{"example.com"}, - }, []string{}) + }}, []string{}) err := ext.conn25.reconfig(sn) if err != nil { t.Fatal(err) @@ -884,6 +891,9 @@ func TestAddressAssignmentIsHandled(t *testing.T) { domain: "example.com.", app: "app1", } + if err := ext.conn25.client.assignments.insert(as); err != nil { + t.Fatalf("error inserting address assignments: %v", err) + } ext.conn25.client.enqueueAddressAssignment(as) select { @@ -942,13 +952,13 @@ func TestMapDNSResponseRewritesResponses(t *testing.T) { configuredDomain := "example.com" domainName := configuredDomain + "." dnsMessageName := dnsmessage.MustNewName(domainName) - sn := makeSelfNode(t, appctype.Conn25Attr{ + sn := makeSelfNode(t, []appctype.Conn25Attr{{ Name: "app1", Connectors: []string{"tag:connector"}, Domains: []string{configuredDomain}, MagicIPPool: []netipx.IPRange{rangeFrom("0", "10")}, TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")}, - }, []string{}) + }}, []string{}) compareToRecords := func(t *testing.T, resources []dnsmessage.Resource, want []netip.Addr) { t.Helper() @@ -1199,10 +1209,253 @@ func TestMapDNSResponseRewritesResponses(t *testing.T) { } } -func TestClientTransitIPForMagicIP(t *testing.T) { - sn := makeSelfNode(t, appctype.Conn25Attr{ - MagicIPPool: []netipx.IPRange{rangeFrom("0", "10")}, // 100.64.0.0 - 100.64.0.10 +func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { + // make a fake peer API to test against, for all peers + peersAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v0/connector/transit-ip" { + http.Error(w, "unexpected path", http.StatusNotFound) + return + } + var req ConnectorTransitIPRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad body", http.StatusBadRequest) + return + } + resp := ConnectorTransitIPResponse{ + TransitIPs: []TransitIPResponse{{Code: OK}}, + } + json.NewEncoder(w).Encode(resp) + })) + defer peersAPI.Close() + + connectorPeers := []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: tailcfg.NodeID(1), + Tags: []string{"tag:woo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x01})), + }).View(), + (&tailcfg.Node{ + ID: tailcfg.NodeID(2), + Tags: []string{"tag:hoo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x02})), + }).View(), + } + + // make extension to test + sys := &tsd.System{} + sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) + + ext := &extension{ + conn25: newConn25(logger.Discard), + backend: &testSafeBackend{sys: sys}, + } + authReconfigAsyncCalled := make(chan struct{}, 1) + if err := ext.Init(&testHost{ + nb: &testNodeBackend{ + peers: connectorPeers, + peerAPIURL: peersAPI.URL, + }, + authReconfigAsync: func() { + authReconfigAsyncCalled <- struct{}{} + }, + }); err != nil { + t.Fatal(err) + } + defer ext.Shutdown() + + sn := makeSelfNode(t, []appctype.Conn25Attr{ + { + Name: "app1", + Connectors: []string{"tag:woo"}, + Domains: []string{"woo.example.com"}, + }, + { + Name: "app2", + Connectors: []string{"tag:hoo"}, + Domains: []string{"hoo.example.com"}, + }, }, []string{}) + err := ext.conn25.reconfig(sn) + if err != nil { + t.Fatal(err) + } + + type lookup struct { + connKey key.NodePublic + expectedIPs []netip.Prefix + expectedOk bool + } + + transitIPs := []netip.Prefix{ + netip.MustParsePrefix("169.254.0.1/32"), + netip.MustParsePrefix("169.254.0.2/32"), + netip.MustParsePrefix("169.254.0.3/32"), + } + // Each step performs an insert on the provided addrs + // and then does the lookups. + steps := []struct { + name string + as addrs + lookups []lookup + }{ + { + name: "step-1-conn1-tip1", + as: addrs{ + dst: netip.MustParseAddr("1.2.3.1"), + magic: netip.MustParseAddr("100.64.0.1"), + transit: transitIPs[0].Addr(), + domain: "woo.example.com.", + app: "app1", + }, + lookups: []lookup{ + { + connKey: connectorPeers[0].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[0], + }, + expectedOk: true, + }, + { + connKey: connectorPeers[1].Key(), + expectedIPs: nil, + expectedOk: false, + }, + }, + }, + { + name: "step-2-conn1-tip2", + as: addrs{ + dst: netip.MustParseAddr("1.2.3.2"), + magic: netip.MustParseAddr("100.64.0.2"), + transit: transitIPs[1].Addr(), + domain: "woo.example.com.", + app: "app1", + }, + lookups: []lookup{ + { + connKey: connectorPeers[0].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[0], + transitIPs[1], + }, + expectedOk: true, + }, + }, + }, + { + name: "step-3-conn2-tip1", + as: addrs{ + dst: netip.MustParseAddr("1.2.3.3"), + magic: netip.MustParseAddr("100.64.0.3"), + transit: transitIPs[2].Addr(), + domain: "hoo.example.com.", + app: "app2", + }, + lookups: []lookup{ + { + connKey: connectorPeers[0].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[0], + transitIPs[1], + }, + expectedOk: true, + }, + { + connKey: connectorPeers[1].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[2], + }, + expectedOk: true, + }, + }, + }, + } + + for _, tt := range steps { + t.Run(tt.name, func(t *testing.T) { + // Add and enqueue the addrs, and then wait for the send to complete + // (as indicated by authReconfig being called). + if err := ext.conn25.client.assignments.insert(tt.as); err != nil { + t.Fatalf("error inserting address assignment: %v", err) + } + if err := ext.conn25.client.enqueueAddressAssignment(tt.as); err != nil { + t.Fatalf("error enqueuing address assignment: %v", err) + } + select { + case <-authReconfigAsyncCalled: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for AuthReconfigAsync to be called") + } + + // Check that each of the lookups behaves as expected + for i, lu := range tt.lookups { + got, ok := ext.conn25.client.assignments.lookupTransitIPsByConnKey(lu.connKey) + if ok != lu.expectedOk { + t.Fatalf("unexpected ok result at index %d wanted %v, got %v", i, lu.expectedOk, ok) + } + slices.SortFunc(got, func(a, b netip.Prefix) int { return a.Compare(b) }) + if diff := cmp.Diff(lu.expectedIPs, got, cmpopts.EquateComparable(netip.Prefix{})); diff != "" { + t.Fatalf("transit IPs mismatch at index %d, (-want +got):\n%s", i, diff) + } + } + }) + } +} + +func TestTransitIPConnMapping(t *testing.T) { + conn25 := newConn25(t.Logf) + + as := addrs{ + dst: netip.MustParseAddr("1.2.3.1"), + magic: netip.MustParseAddr("100.64.0.1"), + transit: netip.MustParseAddr("169.254.0.1"), + domain: "woo.example.com.", + app: "app1", + } + + connectorPeers := []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: tailcfg.NodeID(0), + Tags: []string{"tag:woo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublic{}, + }).View(), + (&tailcfg.Node{ + ID: tailcfg.NodeID(2), + Tags: []string{"tag:hoo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x02})), + }).View(), + } + + // Adding a transit IP that isn't known should fail + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err == nil { + t.Error("adding an unknown transit IP should fail") + } + + // Insert the address assignments + conn25.client.assignments.insert(as) + + // Adding a transit IP for a node with an unset key should fail + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[0]); err == nil { + t.Error("adding an transit IP mapping for a connector with a zero key should fail") + } + // Adding a transit IP that is known should succeed + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err != nil { + t.Errorf("unexpected error for first time add: %v", err) + } + // But doing it again should fail + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err == nil { + t.Error("adding a duplicate transitIP for a connector should fail") + } +} + +func TestClientTransitIPForMagicIP(t *testing.T) { + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + MagicIPPool: []netipx.IPRange{rangeFrom("0", "10")}, // 100.64.0.0 - 100.64.0.10 + }}, []string{}) mappedMip := netip.MustParseAddr("100.64.0.0") mappedTip := netip.MustParseAddr("169.0.0.0") unmappedMip := netip.MustParseAddr("100.64.0.1") @@ -1253,9 +1506,9 @@ func TestClientTransitIPForMagicIP(t *testing.T) { } func TestConnectorRealIPForTransitIPConnection(t *testing.T) { - sn := makeSelfNode(t, appctype.Conn25Attr{ + sn := makeSelfNode(t, []appctype.Conn25Attr{{ TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")}, // 100.64.0.40 - 100.64.0.50 - }, []string{}) + }}, []string{}) mappedSrc := netip.MustParseAddr("100.0.0.1") unmappedSrc := netip.MustParseAddr("100.0.0.2") mappedTip := netip.MustParseAddr("100.64.0.41")