feature/conn25: Store transit ips by connector key (#19071)

The client needs to know the set of transit IPs that are assigned
to each connector, so when we register transit IPs with the connector
we also need to assign them to that connector in the addrAssignments.
We identify the connector by node public key to match the peer information
that is available when the ExtraWireguardAllowedIPs hook will be invoked.

Fixes tailscale/corp#38127

Signed-off-by: George Jones <george@tailscale.com>
This commit is contained in:
George Jones
2026-03-26 15:58:26 -04:00
committed by GitHub
parent 4ace87a965
commit 86135d3df5
2 changed files with 335 additions and 25 deletions

View File

@@ -32,6 +32,7 @@
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/appctype"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/views"
"tailscale.com/util/dnsname"
@@ -123,6 +124,7 @@ func (e *extension) Init(host ipnext.Host) error {
return nil
}
e.host = host
host.Hooks().OnSelfChange.Add(e.onSelfChange)
host.Hooks().ExtraRouterConfigRoutes.Set(e.getMagicRange)
ctx, cancel := context.WithCancelCause(context.Background())
@@ -269,8 +271,7 @@ func (c *Conn25) reconfig(selfNode tailcfg.NodeView) error {
}
// mapDNSResponse parses and inspects the DNS response, and uses the
// contents to assign addresses for connecting. It does not yet modify
// the response.
// contents to assign addresses for connecting.
func (c *Conn25) mapDNSResponse(buf []byte) []byte {
return c.client.mapDNSResponse(buf)
}
@@ -610,6 +611,16 @@ func (c *client) reserveAddresses(domain dnsname.FQDN, dst netip.Addr) (addrs, e
return as, nil
}
func (c *client) addTransitIPForConnector(tip netip.Addr, conn tailcfg.NodeView) error {
if conn.Key().IsZero() {
return fmt.Errorf("node with stable ID %q does not have a key", conn.StableID())
}
c.mu.Lock()
defer c.mu.Unlock()
return c.assignments.insertTransitConnMapping(tip, conn.Key())
}
func (e *extension) sendLoop(ctx context.Context) {
for {
select {
@@ -624,10 +635,15 @@ func (e *extension) sendLoop(ctx context.Context) {
}
func (e *extension) handleAddressAssignment(ctx context.Context, as addrs) error {
if err := e.sendAddressAssignment(ctx, as); err != nil {
conn, err := e.sendAddressAssignment(ctx, as)
if err != nil {
return err
}
// TODO(fran) assign the connector publickey -> transit ip addresses
err = e.conn25.client.addTransitIPForConnector(as.transit, conn)
if err != nil {
return err
}
e.host.AuthReconfigAsync()
return nil
}
@@ -685,27 +701,29 @@ func makePeerAPIReq(ctx context.Context, httpClient *http.Client, urlBase string
return nil
}
func (e *extension) sendAddressAssignment(ctx context.Context, as addrs) error {
func (e *extension) sendAddressAssignment(ctx context.Context, as addrs) (tailcfg.NodeView, error) {
app, ok := e.conn25.client.getConfig().appsByName[as.app]
if !ok {
e.conn25.client.logf("App not found for app: %s (domain: %s)", as.app, as.domain)
return errors.New("app not found")
return tailcfg.NodeView{}, errors.New("app not found")
}
nb := e.host.NodeBackend()
peers := appc.PickConnector(nb, app)
var urlBase string
var conn tailcfg.NodeView
for _, p := range peers {
urlBase = nb.PeerAPIBase(p)
if urlBase != "" {
conn = p
break
}
}
if urlBase == "" {
return errors.New("no connector peer found to handle address assignment")
return tailcfg.NodeView{}, errors.New("no connector peer found to handle address assignment")
}
client := e.backend.Sys().Dialer.Get().PeerAPIHTTPClient()
return makePeerAPIReq(ctx, client, urlBase, as)
return conn, makePeerAPIReq(ctx, client, urlBase, as)
}
type dnsResponseRewrite struct {
@@ -866,6 +884,7 @@ func (c *client) rewriteDNSResponse(hdr dnsmessage.Header, questions []dnsmessag
if err := b.StartAnswers(); err != nil {
return nil, err
}
// make an answer for each rewrite
for _, rw := range answers {
as, err := c.reserveAddresses(rw.domain, rw.dst)
@@ -968,11 +987,15 @@ type domainDst struct {
}
// addrAssignments is the collection of addrs assigned by this client
// supporting lookup by magic IP, transit IP or domain+dst
// supporting lookup by magic IP, transit IP or domain+dst, or to lookup all
// transit IPs associated with a given connector (identified by its node key).
// byConnKey stores netip.Prefix versions of the transit IPs for use in the
// WireGuard hooks.
type addrAssignments struct {
byMagicIP map[netip.Addr]addrs
byTransitIP map[netip.Addr]addrs
byDomainDst map[domainDst]addrs
byConnKey map[key.NodePublic]set.Set[netip.Prefix]
}
func (a *addrAssignments) insert(as addrs) error {
@@ -988,12 +1011,35 @@ func (a *addrAssignments) insert(as addrs) error {
if _, ok := a.byTransitIP[as.transit]; ok {
return errors.New("byTransitIP key exists")
}
mak.Set(&a.byMagicIP, as.magic, as)
mak.Set(&a.byTransitIP, as.transit, as)
mak.Set(&a.byDomainDst, ddst, as)
return nil
}
// insertTransitConnMapping adds an entry to the byConnKey map
// for the provided transitIP (as a prefix).
// The provided transitIP must already be present in the byTransitIP map.
func (a *addrAssignments) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error {
if _, ok := a.lookupByTransitIP(tip); !ok {
return errors.New("transit IP is not already known")
}
ctips, ok := a.byConnKey[connKey]
tipp := netip.PrefixFrom(tip, tip.BitLen())
if ok {
if ctips.Contains(tipp) {
return errors.New("byConnKey already contains transit")
}
} else {
ctips.Make()
mak.Set(&a.byConnKey, connKey, ctips)
}
ctips.Add(tipp)
return nil
}
func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) {
v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}]
return v, ok
@@ -1008,3 +1054,14 @@ func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) {
v, ok := a.byTransitIP[tip]
return v, ok
}
// lookupTransitIPsByConnKey returns a slice containing the transit IPs (as netipPrefix)
// associated with the given connector (identified by node key), or (nil, false) if there is no entry
// for the given key.
func (a *addrAssignments) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) {
s, ok := a.byConnKey[k]
if !ok {
return nil, false
}
return s.Slice(), true
}

View File

@@ -8,11 +8,13 @@
"net/http"
"net/http/httptest"
"net/netip"
"slices"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go4.org/mem"
"go4.org/netipx"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/ipn/ipnext"
@@ -21,6 +23,7 @@
"tailscale.com/tailcfg"
"tailscale.com/tsd"
"tailscale.com/types/appctype"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/opt"
"tailscale.com/util/dnsname"
@@ -525,13 +528,16 @@ func TestConfigReconfig(t *testing.T) {
}
}
func makeSelfNode(t *testing.T, attr appctype.Conn25Attr, tags []string) tailcfg.NodeView {
func makeSelfNode(t *testing.T, attrs []appctype.Conn25Attr, tags []string) tailcfg.NodeView {
t.Helper()
bs, err := json.Marshal(attr)
if err != nil {
t.Fatalf("unexpected error in test setup: %v", err)
cfg := make([]tailcfg.RawMessage, 0, len(attrs))
for i, attr := range attrs {
bs, err := json.Marshal(attr)
if err != nil {
t.Fatalf("unexpected error in test setup at index %d: %v", i, err)
}
cfg = append(cfg, tailcfg.RawMessage(bs))
}
cfg := []tailcfg.RawMessage{tailcfg.RawMessage(bs)}
capMap := tailcfg.NodeCapMap{
tailcfg.NodeCapability(AppConnectorsExperimentalAttrName): cfg,
}
@@ -727,13 +733,13 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
dnsResp := makeDNSResponse(t, tt.domain, tt.addrs)
sn := makeSelfNode(t, appctype.Conn25Attr{
sn := makeSelfNode(t, []appctype.Conn25Attr{{
Name: "app1",
Connectors: []string{"tag:woo"},
Domains: []string{"example.com"},
MagicIPPool: []netipx.IPRange{rangeFrom("0", "10"), rangeFrom("20", "30")},
TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")},
}, []string{})
}}, []string{})
c := newConn25(logger.Discard)
c.reconfig(sn)
@@ -843,6 +849,7 @@ func TestAddressAssignmentIsHandled(t *testing.T) {
ID: tailcfg.NodeID(1),
Tags: []string{"tag:woo"},
Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(),
Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 1: 0xff, 31: 0x01})),
}).View()
// make extension to test
@@ -867,11 +874,11 @@ func TestAddressAssignmentIsHandled(t *testing.T) {
}
defer ext.Shutdown()
sn := makeSelfNode(t, appctype.Conn25Attr{
sn := makeSelfNode(t, []appctype.Conn25Attr{{
Name: "app1",
Connectors: []string{"tag:woo"},
Domains: []string{"example.com"},
}, []string{})
}}, []string{})
err := ext.conn25.reconfig(sn)
if err != nil {
t.Fatal(err)
@@ -884,6 +891,9 @@ func TestAddressAssignmentIsHandled(t *testing.T) {
domain: "example.com.",
app: "app1",
}
if err := ext.conn25.client.assignments.insert(as); err != nil {
t.Fatalf("error inserting address assignments: %v", err)
}
ext.conn25.client.enqueueAddressAssignment(as)
select {
@@ -942,13 +952,13 @@ func TestMapDNSResponseRewritesResponses(t *testing.T) {
configuredDomain := "example.com"
domainName := configuredDomain + "."
dnsMessageName := dnsmessage.MustNewName(domainName)
sn := makeSelfNode(t, appctype.Conn25Attr{
sn := makeSelfNode(t, []appctype.Conn25Attr{{
Name: "app1",
Connectors: []string{"tag:connector"},
Domains: []string{configuredDomain},
MagicIPPool: []netipx.IPRange{rangeFrom("0", "10")},
TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")},
}, []string{})
}}, []string{})
compareToRecords := func(t *testing.T, resources []dnsmessage.Resource, want []netip.Addr) {
t.Helper()
@@ -1199,10 +1209,253 @@ func TestMapDNSResponseRewritesResponses(t *testing.T) {
}
}
func TestClientTransitIPForMagicIP(t *testing.T) {
sn := makeSelfNode(t, appctype.Conn25Attr{
MagicIPPool: []netipx.IPRange{rangeFrom("0", "10")}, // 100.64.0.0 - 100.64.0.10
func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) {
// make a fake peer API to test against, for all peers
peersAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v0/connector/transit-ip" {
http.Error(w, "unexpected path", http.StatusNotFound)
return
}
var req ConnectorTransitIPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad body", http.StatusBadRequest)
return
}
resp := ConnectorTransitIPResponse{
TransitIPs: []TransitIPResponse{{Code: OK}},
}
json.NewEncoder(w).Encode(resp)
}))
defer peersAPI.Close()
connectorPeers := []tailcfg.NodeView{
(&tailcfg.Node{
ID: tailcfg.NodeID(1),
Tags: []string{"tag:woo"},
Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(),
Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x01})),
}).View(),
(&tailcfg.Node{
ID: tailcfg.NodeID(2),
Tags: []string{"tag:hoo"},
Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(),
Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x02})),
}).View(),
}
// make extension to test
sys := &tsd.System{}
sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard})
ext := &extension{
conn25: newConn25(logger.Discard),
backend: &testSafeBackend{sys: sys},
}
authReconfigAsyncCalled := make(chan struct{}, 1)
if err := ext.Init(&testHost{
nb: &testNodeBackend{
peers: connectorPeers,
peerAPIURL: peersAPI.URL,
},
authReconfigAsync: func() {
authReconfigAsyncCalled <- struct{}{}
},
}); err != nil {
t.Fatal(err)
}
defer ext.Shutdown()
sn := makeSelfNode(t, []appctype.Conn25Attr{
{
Name: "app1",
Connectors: []string{"tag:woo"},
Domains: []string{"woo.example.com"},
},
{
Name: "app2",
Connectors: []string{"tag:hoo"},
Domains: []string{"hoo.example.com"},
},
}, []string{})
err := ext.conn25.reconfig(sn)
if err != nil {
t.Fatal(err)
}
type lookup struct {
connKey key.NodePublic
expectedIPs []netip.Prefix
expectedOk bool
}
transitIPs := []netip.Prefix{
netip.MustParsePrefix("169.254.0.1/32"),
netip.MustParsePrefix("169.254.0.2/32"),
netip.MustParsePrefix("169.254.0.3/32"),
}
// Each step performs an insert on the provided addrs
// and then does the lookups.
steps := []struct {
name string
as addrs
lookups []lookup
}{
{
name: "step-1-conn1-tip1",
as: addrs{
dst: netip.MustParseAddr("1.2.3.1"),
magic: netip.MustParseAddr("100.64.0.1"),
transit: transitIPs[0].Addr(),
domain: "woo.example.com.",
app: "app1",
},
lookups: []lookup{
{
connKey: connectorPeers[0].Key(),
expectedIPs: []netip.Prefix{
transitIPs[0],
},
expectedOk: true,
},
{
connKey: connectorPeers[1].Key(),
expectedIPs: nil,
expectedOk: false,
},
},
},
{
name: "step-2-conn1-tip2",
as: addrs{
dst: netip.MustParseAddr("1.2.3.2"),
magic: netip.MustParseAddr("100.64.0.2"),
transit: transitIPs[1].Addr(),
domain: "woo.example.com.",
app: "app1",
},
lookups: []lookup{
{
connKey: connectorPeers[0].Key(),
expectedIPs: []netip.Prefix{
transitIPs[0],
transitIPs[1],
},
expectedOk: true,
},
},
},
{
name: "step-3-conn2-tip1",
as: addrs{
dst: netip.MustParseAddr("1.2.3.3"),
magic: netip.MustParseAddr("100.64.0.3"),
transit: transitIPs[2].Addr(),
domain: "hoo.example.com.",
app: "app2",
},
lookups: []lookup{
{
connKey: connectorPeers[0].Key(),
expectedIPs: []netip.Prefix{
transitIPs[0],
transitIPs[1],
},
expectedOk: true,
},
{
connKey: connectorPeers[1].Key(),
expectedIPs: []netip.Prefix{
transitIPs[2],
},
expectedOk: true,
},
},
},
}
for _, tt := range steps {
t.Run(tt.name, func(t *testing.T) {
// Add and enqueue the addrs, and then wait for the send to complete
// (as indicated by authReconfig being called).
if err := ext.conn25.client.assignments.insert(tt.as); err != nil {
t.Fatalf("error inserting address assignment: %v", err)
}
if err := ext.conn25.client.enqueueAddressAssignment(tt.as); err != nil {
t.Fatalf("error enqueuing address assignment: %v", err)
}
select {
case <-authReconfigAsyncCalled:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for AuthReconfigAsync to be called")
}
// Check that each of the lookups behaves as expected
for i, lu := range tt.lookups {
got, ok := ext.conn25.client.assignments.lookupTransitIPsByConnKey(lu.connKey)
if ok != lu.expectedOk {
t.Fatalf("unexpected ok result at index %d wanted %v, got %v", i, lu.expectedOk, ok)
}
slices.SortFunc(got, func(a, b netip.Prefix) int { return a.Compare(b) })
if diff := cmp.Diff(lu.expectedIPs, got, cmpopts.EquateComparable(netip.Prefix{})); diff != "" {
t.Fatalf("transit IPs mismatch at index %d, (-want +got):\n%s", i, diff)
}
}
})
}
}
func TestTransitIPConnMapping(t *testing.T) {
conn25 := newConn25(t.Logf)
as := addrs{
dst: netip.MustParseAddr("1.2.3.1"),
magic: netip.MustParseAddr("100.64.0.1"),
transit: netip.MustParseAddr("169.254.0.1"),
domain: "woo.example.com.",
app: "app1",
}
connectorPeers := []tailcfg.NodeView{
(&tailcfg.Node{
ID: tailcfg.NodeID(0),
Tags: []string{"tag:woo"},
Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(),
Key: key.NodePublic{},
}).View(),
(&tailcfg.Node{
ID: tailcfg.NodeID(2),
Tags: []string{"tag:hoo"},
Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(),
Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x02})),
}).View(),
}
// Adding a transit IP that isn't known should fail
if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err == nil {
t.Error("adding an unknown transit IP should fail")
}
// Insert the address assignments
conn25.client.assignments.insert(as)
// Adding a transit IP for a node with an unset key should fail
if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[0]); err == nil {
t.Error("adding an transit IP mapping for a connector with a zero key should fail")
}
// Adding a transit IP that is known should succeed
if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err != nil {
t.Errorf("unexpected error for first time add: %v", err)
}
// But doing it again should fail
if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err == nil {
t.Error("adding a duplicate transitIP for a connector should fail")
}
}
func TestClientTransitIPForMagicIP(t *testing.T) {
sn := makeSelfNode(t, []appctype.Conn25Attr{{
MagicIPPool: []netipx.IPRange{rangeFrom("0", "10")}, // 100.64.0.0 - 100.64.0.10
}}, []string{})
mappedMip := netip.MustParseAddr("100.64.0.0")
mappedTip := netip.MustParseAddr("169.0.0.0")
unmappedMip := netip.MustParseAddr("100.64.0.1")
@@ -1253,9 +1506,9 @@ func TestClientTransitIPForMagicIP(t *testing.T) {
}
func TestConnectorRealIPForTransitIPConnection(t *testing.T) {
sn := makeSelfNode(t, appctype.Conn25Attr{
sn := makeSelfNode(t, []appctype.Conn25Attr{{
TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")}, // 100.64.0.40 - 100.64.0.50
}, []string{})
}}, []string{})
mappedSrc := netip.MustParseAddr("100.0.0.1")
unmappedSrc := netip.MustParseAddr("100.0.0.2")
mappedTip := netip.MustParseAddr("100.64.0.41")