diff --git a/cmd/k8s-proxy/k8s-proxy.go b/cmd/k8s-proxy/k8s-proxy.go index 38a86a5e0..c0797c16a 100644 --- a/cmd/k8s-proxy/k8s-proxy.go +++ b/cmd/k8s-proxy/k8s-proxy.go @@ -31,6 +31,7 @@ "k8s.io/utils/strings/slices" "tailscale.com/client/local" "tailscale.com/cmd/k8s-proxy/internal/config" + "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/store" @@ -41,6 +42,7 @@ "tailscale.com/kube/certs" healthz "tailscale.com/kube/health" "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubeclient" "tailscale.com/kube/kubetypes" klc "tailscale.com/kube/localclient" "tailscale.com/kube/metrics" @@ -171,10 +173,31 @@ func run(logger *zap.SugaredLogger) error { // If Pod UID unset, assume we're running outside of a cluster/not managed // by the operator, so no need to set additional state keys. + var kc kubeclient.Client + var stateSecretName string if podUID != "" { if err := state.SetInitialKeys(st, podUID); err != nil { return fmt.Errorf("error setting initial state: %w", err) } + + if cfg.Parsed.State != nil { + if name, ok := strings.CutPrefix(*cfg.Parsed.State, "kube:"); ok { + stateSecretName = name + + kc, err = newKubeClient(stateSecretName) + if err != nil { + return err + } + + var configAuthKey string + if cfg.Parsed.AuthKey != nil { + configAuthKey = *cfg.Parsed.AuthKey + } + if err := resetState(ctx, kc, stateSecretName, podUID, configAuthKey); err != nil { + return fmt.Errorf("error resetting state: %w", err) + } + } + } } var authKey string @@ -197,23 +220,68 @@ func run(logger *zap.SugaredLogger) error { ts.Hostname = *cfg.Parsed.Hostname } - // Make sure we crash loop if Up doesn't complete in reasonable time. - upCtx, upCancel := context.WithTimeout(ctx, time.Minute) - defer upCancel() - if _, err := ts.Up(upCtx); err != nil { - return fmt.Errorf("error starting tailscale server: %w", err) - } - defer ts.Close() lc, err := ts.LocalClient() if err != nil { return fmt.Errorf("error getting local client: %w", err) } - // Setup for updating state keys. + // Make sure we crash loop if Up doesn't complete in reasonable time. + upCtx, upCancel := context.WithTimeout(ctx, 30*time.Second) + defer upCancel() + + // ts.Up() deliberately ignores NeedsLogin because it fires transiently + // during normal auth-key login. We can watch for the login-state health + // warning here though, which only fires on terminal auth failure, and + // cancel early. + go func() { + w, err := lc.WatchIPNBus(upCtx, ipn.NotifyInitialHealthState) + if err != nil { + return + } + defer w.Close() + for { + n, err := w.Next() + if err != nil { + return + } + if n.Health != nil { + if _, ok := n.Health.Warnings[health.LoginStateWarnable.Code]; ok { + upCancel() + return + } + } + } + }() + + if _, err := ts.Up(upCtx); err != nil { + if kc != nil && stateSecretName != "" { + clearTailscaledState(ctx, kc, stateSecretName) + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + return err + } + + defer ts.Close() + if podUID != "" { group.Go(func() error { return state.KeepKeysUpdated(ctx, st, klc.New(lc)) }) + + if kc != nil && stateSecretName != "" { + needsReissue, err := checkInitialAuthState(ctx, lc) + if err != nil { + return fmt.Errorf("error checking initial auth state: %w", err) + } + if needsReissue { + logger.Info("Auth key missing or invalid after startup, requesting new key from operator") + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + + group.Go(func() error { + return monitorAuthHealth(ctx, lc, kc, stateSecretName, cfgChan, configPath, authKey, logger) + }) + } } if cfg.Parsed.HealthCheckEnabled.EqualBool(true) || cfg.Parsed.MetricsEnabled.EqualBool(true) { diff --git a/cmd/k8s-proxy/kube.go b/cmd/k8s-proxy/kube.go new file mode 100644 index 000000000..f7c01ecf0 --- /dev/null +++ b/cmd/k8s-proxy/kube.go @@ -0,0 +1,193 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "strings" + "time" + + "go.uber.org/zap" + "tailscale.com/client/local" + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/kube/authkey" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" +) + +const fieldManager = "tailscale-k8s-proxy" + +// extractStateSecretName extracts the Kubernetes secret name from a state store +// path like "kube:secret-name". +func extractStateSecretName(statePath string) (string, error) { + if !strings.HasPrefix(statePath, "kube:") { + return "", fmt.Errorf("state path %q is not a kube store", statePath) + } + secretName := strings.TrimPrefix(statePath, "kube:") + if secretName == "" { + return "", fmt.Errorf("state path %q has no secret name", statePath) + } + return secretName, nil +} + +// newKubeClient creates a kubeclient for interacting with the state Secret. +func newKubeClient(stateSecretName string) (kubeclient.Client, error) { + kc, err := kubeclient.New(fieldManager) + if err != nil { + return nil, fmt.Errorf("error creating kube client: %w", err) + } + return kc, nil +} + +// resetState clears containerboot/k8s-proxy state from previous runs and sets +// initial values. This ensures the operator doesn't use stale state when a Pod +// is first recreated. +// +// It also clears the reissue_authkey marker if the operator has actioned it +// (i.e., the config now has a different auth key than what was marked for +// reissue). +func resetState(ctx context.Context, kc kubeclient.Client, stateSecretName string, podUID string, configAuthKey string) error { + existingSecret, err := kc.GetSecret(ctx, stateSecretName) + switch { + case kubeclient.IsNotFoundErr(err): + return nil + case err != nil: + return fmt.Errorf("failed to read state Secret %q to reset state: %w", stateSecretName, err) + } + + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyCapVer: fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion), + }, + } + if podUID != "" { + s.Data[kubetypes.KeyPodUID] = []byte(podUID) + } + + // Only clear reissue_authkey if the operator has actioned it. + brokenAuthkey, ok := existingSecret.Data[kubetypes.KeyReissueAuthkey] + if ok && configAuthKey != "" && string(brokenAuthkey) != configAuthKey { + s.Data[kubetypes.KeyReissueAuthkey] = nil + } + + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, fieldManager) +} + +// checkInitialAuthState checks if the tsnet server is in an auth failure state +// immediately after coming up. Returns true if auth key reissue is needed. +func checkInitialAuthState(ctx context.Context, lc *local.Client) (bool, error) { + status, err := lc.Status(ctx) + if err != nil { + return false, fmt.Errorf("error getting status: %w", err) + } + + if status.BackendState == ipn.NeedsLogin.String() { + return true, nil + } + + // Status.Health is a []string of health warnings. + loginWarnableCode := string(health.LoginStateWarnable.Code) + for _, h := range status.Health { + if strings.Contains(h, loginWarnableCode) { + return true, nil + } + } + + return false, nil +} + +// monitorAuthHealth watches the IPN bus for auth failures and triggers reissue +// when needed. Runs until context is cancelled or auth failure is detected. +func monitorAuthHealth(ctx context.Context, lc *local.Client, kc kubeclient.Client, stateSecretName string, cfgChan <-chan *conf.Config, configPath string, authKey string, logger *zap.SugaredLogger) error { + w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialHealthState|ipn.NotifyInitialState) + if err != nil { + return fmt.Errorf("failed to watch IPN bus for auth health: %w", err) + } + defer w.Close() + + for { + n, err := w.Next() + if err != nil { + if err == ctx.Err() { + return nil + } + return err + } + + if n.State != nil && *n.State == ipn.NeedsLogin { + logger.Info("Auth key missing or invalid (NeedsLogin state), disconnecting from control and requesting new key from operator") + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + + if n.Health != nil { + if _, ok := n.Health.Warnings[health.LoginStateWarnable.Code]; ok { + logger.Info("Auth key failed to authenticate (may be expired or single-use), disconnecting from control and requesting new key from operator") + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + } + } +} + +func clearTailscaledState(ctx context.Context, kc kubeclient.Client, stateSecretName string) error { + secret, err := kc.GetSecret(ctx, stateSecretName) + if err != nil { + return fmt.Errorf("error reading state Secret: %w", err) + } + + s := &kubeapi.Secret{ + Data: map[string][]byte{ + "_machinekey": nil, + "_current-profile": nil, + }, + } + + // The profile key name is stored in _current-profile (e.g. "profile-a716"). + if profileKey := string(secret.Data["_current-profile"]); profileKey != "" { + s.Data[profileKey] = nil + } + + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, fieldManager) +} + +// handleAuthKeyReissue orchestrates the auth key reissue flow: +// 1. Disconnect from control +// 2. Set reissue marker in state Secret +// 3. Wait for operator to provide new key +// 4. Exit cleanly (Kubernetes will restart the pod with the new key) +func handleAuthKeyReissue(ctx context.Context, lc *local.Client, kc kubeclient.Client, stateSecretName string, currentAuthKey string, cfgChan <-chan *conf.Config, logger *zap.SugaredLogger) error { + if err := lc.DisconnectControl(ctx); err != nil { + return fmt.Errorf("error disconnecting from control: %w", err) + } + if err := authkey.SetReissueAuthKey(ctx, kc, stateSecretName, currentAuthKey); err != nil { + return fmt.Errorf("failed to set reissue_authkey in Kubernetes Secret: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) + defer cancel() + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for auth key reissue") + case cfg := <-cfgChan: + if cfg.Parsed.AuthKey != nil && *cfg.Parsed.AuthKey != currentAuthKey { + if err := authkey.ClearReissueAuthKey(ctx, kc, stateSecretName); err != nil { + logger.Warnf("failed to clear reissue request: %v", err) + } + logger.Info("Successfully received new auth key, restarting to apply configuration") + err := clearTailscaledState(ctx, kc, stateSecretName) + if err != nil { + return fmt.Errorf("failed to clear tailscaled state: %w", err) + } + return nil + } + } + } +} diff --git a/cmd/k8s-proxy/kube_test.go b/cmd/k8s-proxy/kube_test.go new file mode 100644 index 000000000..0a37f1ab1 --- /dev/null +++ b/cmd/k8s-proxy/kube_test.go @@ -0,0 +1,153 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" +) + +func TestExtractStateSecretName(t *testing.T) { + tests := []struct { + name string + input string + want string + wantError bool + }{ + { + name: "valid_kube_path", + input: "kube:tailscale-state", + want: "tailscale-state", + }, + { + name: "valid_kube_path_with_namespace", + input: "kube:tailscale-state-ns", + want: "tailscale-state-ns", + }, + { + name: "non_kube_path", + input: "mem:", + wantError: true, + }, + { + name: "empty_secret_name", + input: "kube:", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractStateSecretName(tt.input) + if tt.wantError { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestResetState(t *testing.T) { + tests := []struct { + name string + existingData map[string][]byte + podUID string + configAuthKey string + wantPatched map[string][]byte + }{ + { + name: "clears_device_state", + existingData: map[string][]byte{ + kubetypes.KeyDeviceID: []byte("device-123"), + kubetypes.KeyDeviceFQDN: []byte("node.tailnet"), + kubetypes.KeyDeviceIPs: []byte(`["100.64.0.1"]`), + }, + podUID: "pod-123", + configAuthKey: "new-key", + wantPatched: map[string][]byte{ + kubetypes.KeyCapVer: []byte("95"), + kubetypes.KeyPodUID: []byte("pod-123"), + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + }, + }, + { + name: "clears_reissue_marker_when_actioned", + existingData: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-key"), + }, + podUID: "pod-123", + configAuthKey: "new-key", + wantPatched: map[string][]byte{ + kubetypes.KeyCapVer: []byte("95"), + kubetypes.KeyPodUID: []byte("pod-123"), + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + kubetypes.KeyReissueAuthkey: nil, + }, + }, + { + name: "keeps_reissue_marker_when_not_actioned", + existingData: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-key"), + }, + podUID: "pod-123", + configAuthKey: "old-key", + wantPatched: map[string][]byte{ + kubetypes.KeyCapVer: []byte("95"), + kubetypes.KeyPodUID: []byte("pod-123"), + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Adjust expected cap ver to match actual current version. + tt.wantPatched[kubetypes.KeyCapVer] = []byte{0} + tt.wantPatched[kubetypes.KeyCapVer] = fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion) + + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{Data: tt.existingData}, nil + }, + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, s *kubeapi.Secret, fm string) error { + patched = s.Data + return nil + }, + } + + err := resetState(context.Background(), kc, "test-secret", tt.podUID, tt.configAuthKey) + if err != nil { + t.Fatalf("resetState() error = %v", err) + } + + if diff := cmp.Diff(tt.wantPatched, patched); diff != "" { + t.Errorf("resetState() mismatch (-want +got):\n%s", diff) + } + }) + } +}