diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 9d5a7d2a8..cb56f701b 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -192,9 +192,9 @@ func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() for c := range srv.activeConns { - if c.info == nil { - // c.info is nil when the connection hasn't been authenticated yet. - // In that case, the connection will be terminated when it is. + if !c.authCompleted.Load() { + // The connection hasn't completed authentication yet. + // In that case, the connection will be terminated when it does. continue } go c.checkStillValid() @@ -236,14 +236,26 @@ type conn struct { // Banners cannot be sent after auth completes. spac gossh.ServerPreAuthConn + // The following fields are set during clientAuth and are used for policy + // evaluation and session management. They are immutable after clientAuth + // completes. They must not be read from other goroutines until + // authCompleted is set to true. + action0 *tailcfg.SSHAction // set by clientAuth finalAction *tailcfg.SSHAction // set by clientAuth - info *sshConnInfo // set by setInfo + info *sshConnInfo // set by setInfo (during clientAuth) localUser *userMeta // set by clientAuth userGroupIDs []string // set by clientAuth acceptEnv []string + // authCompleted is set to true after clientAuth has finished writing + // all authentication state fields (info, localUser, action0, + // finalAction, userGroupIDs, acceptEnv). It provides a memory + // barrier so that concurrent readers (e.g. OnPolicyChange) see + // fully-initialized values. + authCompleted atomic.Bool + // mu protects the following fields. // // srv.mu should be acquired prior to mu. @@ -369,6 +381,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE } } c.finalAction = action + c.authCompleted.Store(true) return &gossh.Permissions{}, nil case action.Reject: metricTerminalReject.Add(1) diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 44db0cc00..6d9d859a2 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -31,6 +31,7 @@ "sync" "sync/atomic" "testing" + "testing/synctest" "time" gossh "golang.org/x/crypto/ssh" @@ -1111,6 +1112,7 @@ func TestSSH(t *testing.T) { } sc.action0 = &tailcfg.SSHAction{Accept: true} sc.finalAction = sc.action0 + sc.authCompleted.Store(true) sc.Handler = func(s ssh.Session) { sc.newSSHSession(s).run() @@ -1320,6 +1322,79 @@ func TestStdOsUserUserAssumptions(t *testing.T) { } } +func TestOnPolicyChangeSkipsPreAuthConns(t *testing.T) { + tests := []struct { + name string + sshRule *tailcfg.SSHRule + wantCancel bool + }{ + { + name: "accept-after-auth", + sshRule: newSSHRule(&tailcfg.SSHAction{Accept: true}), + wantCancel: false, + }, + { + name: "reject-after-auth", + sshRule: newSSHRule(&tailcfg.SSHAction{Reject: true}), + wantCancel: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + srv := &server{ + logf: tstest.WhileTestRunningLogger(t), + lb: &localState{ + sshEnabled: true, + matchingRule: tt.sshRule, + }, + } + c := &conn{ + srv: srv, + info: &sshConnInfo{ + sshUser: "alice", + src: netip.MustParseAddrPort("1.2.3.4:30343"), + dst: netip.MustParseAddrPort("100.100.100.102:22"), + }, + localUser: &userMeta{User: user.User{Username: currentUser}}, + } + srv.activeConns = map[*conn]bool{c: true} + ctx, cancel := context.WithCancelCause(context.Background()) + ss := &sshSession{ctx: ctx, cancelCtx: cancel} + c.sessions = []*sshSession{ss} + + // Before authCompleted is set, OnPolicyChange should skip + // the conn entirely — no goroutine spawned. + srv.OnPolicyChange() + synctest.Wait() + select { + case <-ctx.Done(): + t.Fatal("session canceled before auth completed") + default: + } + + // Mark auth as completed. Now OnPolicyChange should + // evaluate the policy and act accordingly. + c.authCompleted.Store(true) + + srv.OnPolicyChange() + synctest.Wait() + select { + case <-ctx.Done(): + if !tt.wantCancel { + t.Fatal("valid session should not have been canceled") + } + default: + if tt.wantCancel { + t.Fatal("invalid session should have been canceled") + } + } + }) + }) + } +} + func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { t.Helper() mux := http.NewServeMux()