diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 2b6f31291..bb1c8c26a 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -139,6 +139,7 @@ "github.com/benbjohnson/immutable" "golang.org/x/sys/unix" + "tailscale.com/client/local" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -212,6 +213,23 @@ func (s netmapState) updateFromNotify(n ipn.Notify) netmapState { return s } +// processNotify updates the netmap state from an IPN bus Notify. On +// SelfChange it also refetches DNS via the LocalAPI dns-config +// endpoint; the bus carries no DNS delta. +func (s netmapState) processNotify(ctx context.Context, client *local.Client, n ipn.Notify) netmapState { + s = s.updateFromNotify(n) + if n.SelfChange != nil { + dns, err := client.DNSConfig(ctx) + if err != nil { + log.Printf("error refreshing DNS config from tailscaled: %v", err) + } else if dns != nil { + s.dnsExtraRecords = views.SliceOf(dns.ExtraRecords) + s.certDomains = views.SliceOf(dns.CertDomains) + } + } + return s +} + func (s netmapState) updateFromStatus(st *ipnstate.Status) netmapState { s.certDomains = views.SliceOf(st.CertDomains) s.dnsExtraRecords = views.SliceOf(st.ExtraRecords) @@ -703,7 +721,7 @@ func run() error { case err := <-cfgWatchErrChan: return fmt.Errorf("failed to watch tailscaled config: %w", err) case n := <-notifyChan: - nmState = nmState.updateFromNotify(n) + nmState = nmState.processNotify(ctx, client, n) if state, ok := notifyState(n); ok && state != ipn.Running { // Something's gone wrong and we've left the authenticated state. // Our container image never recovered gracefully from this, and the diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 6b64f3c43..0013bb8e4 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -7,6 +7,7 @@ import ( "bytes" + "context" _ "embed" "encoding/base64" "encoding/json" @@ -32,6 +33,7 @@ "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" + "tailscale.com/client/local" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/health" "tailscale.com/ipn" @@ -39,6 +41,7 @@ "tailscale.com/kube/egressservices" "tailscale.com/kube/kubeclient" "tailscale.com/kube/kubetypes" + "tailscale.com/net/memnet" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/key" @@ -1945,3 +1948,57 @@ func newTestEnv(t *testing.T) testEnv { healthAddrPort: healthAddrPort, } } + +// TestProcessNotifyRefreshesDNSOnSelfChange verifies that a SelfChange +// notification triggers a DNS refresh; without it, VIPServices created +// after pod boot are invisible to resolveTailnetFQDN. +func TestProcessNotifyRefreshesDNSOnSelfChange(t *testing.T) { + extraRec := tailcfg.DNSRecord{ + Name: "my-ingress.tailnet.ts.net.", + Type: "A", + Value: "100.99.10.20", + } + dnsCfg := &tailcfg.DNSConfig{ + ExtraRecords: []tailcfg.DNSRecord{extraRec}, + CertDomains: []string{"node.tailnet.ts.net"}, + } + + lal := memnet.Listen("local-tailscaled.sock:80") + defer lal.Close() + mux := http.NewServeMux() + mux.HandleFunc("/localapi/v0/dns-config", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(dnsCfg); err != nil { + t.Errorf("encoding dns config: %v", err) + } + }) + srv := &http.Server{Handler: mux} + go srv.Serve(lal) + t.Cleanup(func() { srv.Shutdown(context.Background()) }) + + client := &local.Client{Dial: lal.Dial} + + // Empty starting state, as if the InitialStatus captured at pod + // boot carried no ExtraRecords because the VIPService didn't exist + // yet at that time. + var s netmapState + + n := ipn.Notify{ + SelfChange: &tailcfg.Node{ + ID: 1, + Name: "self.tailnet.ts.net.", + }, + } + + got := s.processNotify(context.Background(), client, n) + + if got.dnsExtraRecords.Len() != 1 { + t.Fatalf("dnsExtraRecords.Len() = %d, want 1", got.dnsExtraRecords.Len()) + } + if rec := got.dnsExtraRecords.At(0); rec.Name != extraRec.Name { + t.Errorf("dnsExtraRecords[0].Name = %q, want %q", rec.Name, extraRec.Name) + } + if got.certDomains.Len() != 1 || got.certDomains.At(0) != "node.tailnet.ts.net" { + t.Errorf("certDomains = %v, want [node.tailnet.ts.net]", got.certDomains.AsSlice()) + } +}