diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index 764634d30..027e7c810 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -37,7 +37,6 @@ "tailscale.com/feature/buildfeatures" "tailscale.com/hostinfo" "tailscale.com/ipn" - "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store" "tailscale.com/ipn/store/mem" "tailscale.com/net/bakedroots" @@ -106,6 +105,13 @@ func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertK // // If a cert is expired, or expires sooner than minValidity, it will be renewed // synchronously. Otherwise it will be renewed asynchronously. +// +// The domain must be one of: +// +// - An exact CertDomain (e.g., "node.ts.net") +// - A wildcard domain (e.g., "*.node.ts.net") +// +// The wildcard format requires the NodeAttrDNSSubdomainResolve capability. func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string, minValidity time.Duration) (*TLSCertKeyPair, error) { b.mu.Lock() getCertForTest := b.getCertForTest @@ -119,6 +125,13 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if !validLookingCertDomain(domain) { return nil, errors.New("invalid domain") } + + certDomain, err := b.resolveCertDomain(domain) + if err != nil { + return nil, err + } + storageKey := strings.TrimPrefix(certDomain, "*.") + logf := logger.WithPrefix(b.logf, fmt.Sprintf("cert(%q): ", domain)) now := b.clock.Now() traceACME := func(v any) { @@ -134,13 +147,13 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string return nil, err } - if pair, err := getCertPEMCached(cs, domain, now); err == nil { + if pair, err := getCertPEMCached(cs, storageKey, now); err == nil { if envknob.IsCertShareReadOnlyMode() { return pair, nil } // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) + shouldRenew, err := b.shouldStartDomainRenewal(cs, storageKey, now, pair, minValidity) if err != nil { logf("error checking for certificate renewal: %v", err) // Renewal check failed, but the current cert is valid and not @@ -154,7 +167,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string logf("starting async renewal") // Start renewal in the background, return current valid cert. b.goTracker.Go(func() { - if _, err := getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity); err != nil { + if _, err := getCertPEM(context.Background(), b, cs, logf, traceACME, certDomain, now, minValidity); err != nil { logf("async renewal failed: getCertPem: %v", err) } }) @@ -169,7 +182,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err) } - pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) + pair, err := getCertPEM(ctx, b, cs, logf, traceACME, certDomain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -506,19 +519,24 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey } // getCertPem checks if a cert needs to be renewed and if so, renews it. +// domain is the resolved cert domain (e.g., "*.node.ts.net" for wildcards). // It can be overridden in tests. var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() + // storageKey is used for file storage and renewal tracking. + // For wildcards, "*.node.ts.net" -> "node.ts.net" + storageKey, isWildcard := strings.CutPrefix(domain, "*.") + // In case this method was triggered multiple times in parallel (when // serving incoming requests), check whether one of the other goroutines // already renewed the cert before us. - previous, err := getCertPEMCached(cs, domain, now) + previous, err := getCertPEMCached(cs, storageKey, now) if err == nil { // shouldStartDomainRenewal caches its result so it's OK to call this // frequently. - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, previous, minValidity) + shouldRenew, err := b.shouldStartDomainRenewal(cs, storageKey, now, previous, minValidity) if err != nil { logf("error checking for certificate renewal: %v", err) } else if !shouldRenew { @@ -561,12 +579,6 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return nil, fmt.Errorf("unexpected ACME account status %q", a.Status) } - // Before hitting LetsEncrypt, see if this is a domain that Tailscale will do DNS challenges for. - st := b.StatusWithoutPeers() - if err := checkCertDomain(st, domain); err != nil { - return nil, err - } - // If we have a previous cert, include it in the order. Assuming we're // within the ARI renewal window this should exclude us from LE rate // limits. @@ -580,7 +592,18 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey opts = append(opts, acme.WithOrderReplacesCert(prevCrt)) } } - order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: domain}}, opts...) + + // For wildcards, we need to authorize both the wildcard and base domain. + var authzIDs []acme.AuthzID + if isWildcard { + authzIDs = []acme.AuthzID{ + {Type: "dns", Value: domain}, + {Type: "dns", Value: storageKey}, + } + } else { + authzIDs = []acme.AuthzID{{Type: "dns", Value: domain}} + } + order, err := ac.AuthorizeOrder(ctx, authzIDs, opts...) if err != nil { return nil, err } @@ -598,7 +621,9 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey if err != nil { return nil, err } - key := "_acme-challenge." + domain + // For wildcards, the challenge is on the base domain. + // e.g., "*.node.ts.net" -> "_acme-challenge.node.ts.net" + key := "_acme-challenge." + strings.TrimPrefix(az.Identifier.Value, "*.") // Do a best-effort lookup to see if we've already created this DNS name // in a previous attempt. Don't burn too much time on it, though. Worst @@ -608,14 +633,14 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey txts, _ := resolver.LookupTXT(lookupCtx, key) lookupCancel() if slices.Contains(txts, rec) { - logf("TXT record already existed") + logf("TXT record already existed for %s", key) } else { - logf("starting SetDNS call...") + logf("starting SetDNS call for %s...", key) err = b.SetDNS(ctx, key, rec) if err != nil { return nil, fmt.Errorf("SetDNS %q => %q: %w", key, rec, err) } - logf("did SetDNS") + logf("did SetDNS for %s", key) } chal, err := ac.Accept(ctx, ch) @@ -672,19 +697,27 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return nil, err } } - if err := cs.WriteTLSCertAndKey(domain, certPEM.Bytes(), privPEM.Bytes()); err != nil { + if err := cs.WriteTLSCertAndKey(storageKey, certPEM.Bytes(), privPEM.Bytes()); err != nil { return nil, err } - b.domainRenewed(domain) + b.domainRenewed(storageKey) return &TLSCertKeyPair{CertPEM: certPEM.Bytes(), KeyPEM: privPEM.Bytes()}, nil } -// certRequest generates a CSR for the given common name cn and optional SANs. -func certRequest(key crypto.Signer, name string, ext []pkix.Extension) ([]byte, error) { +// certRequest generates a CSR for the given domain and optional SANs. +func certRequest(key crypto.Signer, domain string, ext []pkix.Extension) ([]byte, error) { + dnsNames := []string{domain} + if base, ok := strings.CutPrefix(domain, "*."); ok { + // Wildcard cert must also include the base domain as a SAN. + // This is load-bearing: getCertPEMCached validates certs using + // the storage key (base domain), which only passes x509 verification + // if the base domain is in DNSNames. + dnsNames = append(dnsNames, base) + } req := &x509.CertificateRequest{ - Subject: pkix.Name{CommonName: name}, - DNSNames: []string{name}, + Subject: pkix.Name{CommonName: domain}, + DNSNames: dnsNames, ExtraExtensions: ext, } return x509.CreateCertificateRequest(rand.Reader, req, key) @@ -844,7 +877,7 @@ func isDefaultDirectoryURL(u string) bool { // we might be able to get a cert for. // // It's a light check primarily for double checking before it's used -// as part of a filesystem path. The actual validation happens in checkCertDomain. +// as part of a filesystem path. The actual validation happens in resolveCertDomain. func validLookingCertDomain(name string) bool { if name == "" || strings.Contains(name, "..") || @@ -852,22 +885,56 @@ func validLookingCertDomain(name string) bool { !strings.Contains(name, ".") { return false } + // Only allow * as a wildcard prefix "*.domain.tld" + if rest, ok := strings.CutPrefix(name, "*."); ok { + if strings.Contains(rest, "*") || !strings.Contains(rest, ".") { + return false + } + } else if strings.Contains(name, "*") { + return false + } return true } -func checkCertDomain(st *ipnstate.Status, domain string) error { +// resolveCertDomain validates a domain and returns the cert domain to use. +// +// - "node.ts.net" -> "node.ts.net" (exact CertDomain match) +// - "*.node.ts.net" -> "*.node.ts.net" (explicit wildcard, requires NodeAttrDNSSubdomainResolve) +// +// Subdomain requests like "app.node.ts.net" are rejected; callers should +// request "*.node.ts.net" explicitly for subdomain coverage. +func (b *LocalBackend) resolveCertDomain(domain string) (string, error) { if domain == "" { - return errors.New("missing domain name") + return "", errors.New("missing domain name") } - for _, d := range st.CertDomains { - if d == domain { - return nil + + // Read the netmap once to get both CertDomains and capabilities atomically. + nm := b.NetMap() + if nm == nil { + return "", errors.New("no netmap available") + } + certDomains := nm.DNS.CertDomains + if len(certDomains) == 0 { + return "", errors.New("your Tailscale account does not support getting TLS certs") + } + + // Wildcard request like "*.node.ts.net". + if base, ok := strings.CutPrefix(domain, "*."); ok { + if !nm.AllCaps.Contains(tailcfg.NodeAttrDNSSubdomainResolve) { + return "", fmt.Errorf("wildcard certificates are not enabled for this node") } + if !slices.Contains(certDomains, base) { + return "", fmt.Errorf("invalid domain %q; parent domain must be one of %q", domain, certDomains) + } + return domain, nil } - if len(st.CertDomains) == 0 { - return errors.New("your Tailscale account does not support getting TLS certs") + + // Exact CertDomain match. + if slices.Contains(certDomains, domain) { + return domain, nil } - return fmt.Errorf("invalid domain %q; must be one of %q", domain, st.CertDomains) + + return "", fmt.Errorf("invalid domain %q; must be one of %q", domain, certDomains) } // handleC2NTLSCertStatus returns info about the last TLS certificate issued for the @@ -884,7 +951,7 @@ func handleC2NTLSCertStatus(b *LocalBackend, w http.ResponseWriter, r *http.Requ return } - domain := r.FormValue("domain") + domain := strings.TrimPrefix(r.FormValue("domain"), "*.") if domain == "" { http.Error(w, "no 'domain'", http.StatusBadRequest) return diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index ec7be570c..b8acb710a 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -17,17 +17,205 @@ "math/big" "os" "path/filepath" + "slices" "testing" "time" "github.com/google/go-cmp/cmp" "tailscale.com/envknob" "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/types/netmap" "tailscale.com/util/must" + "tailscale.com/util/set" ) +func TestCertRequest(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + tests := []struct { + domain string + wantSANs []string + }{ + { + domain: "example.com", + wantSANs: []string{"example.com"}, + }, + { + domain: "*.example.com", + wantSANs: []string{"*.example.com", "example.com"}, + }, + { + domain: "*.foo.bar.com", + wantSANs: []string{"*.foo.bar.com", "foo.bar.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + csrDER, err := certRequest(key, tt.domain, nil) + if err != nil { + t.Fatalf("certRequest: %v", err) + } + csr, err := x509.ParseCertificateRequest(csrDER) + if err != nil { + t.Fatalf("ParseCertificateRequest: %v", err) + } + if csr.Subject.CommonName != tt.domain { + t.Errorf("CommonName = %q, want %q", csr.Subject.CommonName, tt.domain) + } + if !slices.Equal(csr.DNSNames, tt.wantSANs) { + t.Errorf("DNSNames = %v, want %v", csr.DNSNames, tt.wantSANs) + } + }) + } +} + +func TestResolveCertDomain(t *testing.T) { + tests := []struct { + name string + domain string + certDomains []string + hasCap bool + skipNetmap bool + want string + wantErr string + }{ + { + name: "exact_match", + domain: "node.ts.net", + certDomains: []string{"node.ts.net"}, + want: "node.ts.net", + }, + { + name: "exact_match_with_cap", + domain: "node.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: true, + want: "node.ts.net", + }, + { + name: "wildcard_with_cap", + domain: "*.node.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: true, + want: "*.node.ts.net", + }, + { + name: "wildcard_without_cap", + domain: "*.node.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: false, + wantErr: "wildcard certificates are not enabled for this node", + }, + { + name: "subdomain_with_cap_rejected", + domain: "app.node.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: true, + wantErr: `invalid domain "app.node.ts.net"; must be one of ["node.ts.net"]`, + }, + { + name: "subdomain_without_cap_rejected", + domain: "app.node.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: false, + wantErr: `invalid domain "app.node.ts.net"; must be one of ["node.ts.net"]`, + }, + { + name: "multi_level_subdomain_rejected", + domain: "a.b.node.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: true, + wantErr: `invalid domain "a.b.node.ts.net"; must be one of ["node.ts.net"]`, + }, + { + name: "wildcard_no_matching_parent", + domain: "*.unrelated.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: true, + wantErr: `invalid domain "*.unrelated.ts.net"; parent domain must be one of ["node.ts.net"]`, + }, + { + name: "subdomain_unrelated_rejected", + domain: "app.unrelated.ts.net", + certDomains: []string{"node.ts.net"}, + hasCap: true, + wantErr: `invalid domain "app.unrelated.ts.net"; must be one of ["node.ts.net"]`, + }, + { + name: "no_cert_domains", + domain: "node.ts.net", + certDomains: nil, + wantErr: "your Tailscale account does not support getting TLS certs", + }, + { + name: "wildcard_no_cert_domains", + domain: "*.foo.ts.net", + certDomains: nil, + hasCap: true, + wantErr: "your Tailscale account does not support getting TLS certs", + }, + { + name: "empty_domain", + domain: "", + certDomains: []string{"node.ts.net"}, + wantErr: "missing domain name", + }, + { + name: "nil_netmap", + domain: "node.ts.net", + skipNetmap: true, + wantErr: "no netmap available", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := newTestLocalBackend(t) + + if !tt.skipNetmap { + // Set up netmap with CertDomains and capability + var allCaps set.Set[tailcfg.NodeCapability] + if tt.hasCap { + allCaps = set.Of(tailcfg.NodeAttrDNSSubdomainResolve) + } + b.mu.Lock() + b.currentNode().SetNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{}).View(), + DNS: tailcfg.DNSConfig{ + CertDomains: tt.certDomains, + }, + AllCaps: allCaps, + }) + b.mu.Unlock() + } + + got, err := b.resolveCertDomain(tt.domain) + if tt.wantErr != "" { + if err == nil { + t.Errorf("resolveCertDomain(%q) = %q, want error %q", tt.domain, got, tt.wantErr) + } else if err.Error() != tt.wantErr { + t.Errorf("resolveCertDomain(%q) error = %q, want %q", tt.domain, err.Error(), tt.wantErr) + } + return + } + if err != nil { + t.Errorf("resolveCertDomain(%q) error = %v, want nil", tt.domain, err) + return + } + if got != tt.want { + t.Errorf("resolveCertDomain(%q) = %q, want %q", tt.domain, got, tt.want) + } + }) + } +} + func TestValidLookingCertDomain(t *testing.T) { tests := []struct { in string @@ -40,6 +228,16 @@ func TestValidLookingCertDomain(t *testing.T) { {"", false}, {"foo\\bar.com", false}, {"foo\x00bar.com", false}, + // Wildcard tests + {"*.foo.com", true}, + {"*.foo.bar.com", true}, + {"*foo.com", false}, // must be *. + {"*.com", false}, // must have domain after *. + {"*.", false}, // must have domain after *. + {"*.*.foo.com", false}, // no nested wildcards + {"foo.*.bar.com", false}, // no wildcard mid-string + {"app.foo.com", true}, // regular subdomain + {"*", false}, // bare asterisk } for _, tt := range tests { if got := validLookingCertDomain(tt.in); got != tt.want { @@ -231,12 +429,19 @@ func TestDebugACMEDirectoryURL(t *testing.T) { func TestGetCertPEMWithValidity(t *testing.T) { const testDomain = "example.com" - b := &LocalBackend{ - store: &mem.Store{}, - varRoot: t.TempDir(), - ctx: context.Background(), - logf: t.Logf, - } + b := newTestLocalBackend(t) + b.varRoot = t.TempDir() + + // Set up netmap with CertDomains so resolveCertDomain works + b.mu.Lock() + b.currentNode().SetNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{}).View(), + DNS: tailcfg.DNSConfig{ + CertDomains: []string{testDomain}, + }, + }) + b.mu.Unlock() + certDir, err := b.certDir() if err != nil { t.Fatalf("certDir error: %v", err)