From e8ecbd7af17e4c25c332e9684de8bb5d9b2cdd9d Mon Sep 17 00:00:00 2001 From: Florian Schade Date: Fri, 20 Feb 2026 16:52:29 +0100 Subject: [PATCH] refactor: make the logout mode private --- services/proxy/pkg/middleware/oidc_auth.go | 7 ++- .../pkg/staticroutes/backchannellogout.go | 5 +-- .../backchannellogout/backchannellogout.go | 41 +++++++++-------- .../backchannellogout_test.go | 44 +++---------------- 4 files changed, 33 insertions(+), 64 deletions(-) diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index 3a3ed29fdc..00822cdc99 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" "github.com/vmihailenco/msgpack/v5" "go-micro.dev/v4/store" @@ -117,13 +116,13 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri m.Logger.Error().Err(err).Msg("failed to write to userinfo cache") } - // if the claim has no subject, we can leave it empty, - // it's important to keep the dot in the key to prevent - // sufix and prefix exploration in the cache. + // fail if creating the storage key fails, + // it means there is no subject and no session. // // ok: {key: ".sessionId"} // ok: {key: "subject."} // ok: {key: "subject.sessionId"} + // fail: {key: "."} subjectSessionKey, err := staticroutes.NewRecordKey(aClaims.Subject, aClaims.SessionID) if err != nil { m.Logger.Error().Err(err).Msg("failed to build subject.session") diff --git a/services/proxy/pkg/staticroutes/backchannellogout.go b/services/proxy/pkg/staticroutes/backchannellogout.go index d7e10d5da8..53375d63bb 100644 --- a/services/proxy/pkg/staticroutes/backchannellogout.go +++ b/services/proxy/pkg/staticroutes/backchannellogout.go @@ -75,10 +75,7 @@ func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re return } - // find out which mode of backchannel logout we are in - // by checking if the session or subject is present in the token - logoutMode := bcl.GetLogoutMode(requestSubjectAndSession) - lookupRecords, err := bcl.GetLogoutRecords(requestSubjectAndSession, logoutMode, s.UserInfoCache) + lookupRecords, err := bcl.GetLogoutRecords(requestSubjectAndSession, s.UserInfoCache) if errors.Is(err, microstore.ErrNotFound) || len(lookupRecords) == 0 { render.Status(r, http.StatusOK) render.JSON(w, r, nil) diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go index 1863047031..86ee00556b 100644 --- a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go @@ -97,27 +97,27 @@ func NewSuSe(key string) (SuSe, error) { return suse, nil } -// LogoutMode defines the mode of backchannel logout, either by session or by subject -type LogoutMode int +// logoutMode defines the mode of backchannel logout, either by session or by subject +type logoutMode int const ( - // LogoutModeUndefined is used when the logout mode cannot be determined - LogoutModeUndefined LogoutMode = iota - // LogoutModeSubject is used when the logout mode is determined by the subject - LogoutModeSubject - // LogoutModeSession is used when the logout mode is determined by the session id - LogoutModeSession + // logoutModeUndefined is used when the logout mode cannot be determined + logoutModeUndefined logoutMode = iota + // logoutModeSubject is used when the logout mode is determined by the subject + logoutModeSubject + // logoutModeSession is used when the logout mode is determined by the session id + logoutModeSession ) -// GetLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct -func GetLogoutMode(suse SuSe) LogoutMode { +// getLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct +func getLogoutMode(suse SuSe) logoutMode { switch { case suse.encodedSession == "" && suse.encodedSubject != "": - return LogoutModeSubject + return logoutModeSubject case suse.encodedSession != "": - return LogoutModeSession + return logoutModeSession default: - return LogoutModeUndefined + return logoutModeUndefined } } @@ -128,16 +128,19 @@ var ErrSuspiciousCacheResult = errors.New("suspicious cache result") // logout mode and the provided SuSe struct. // it uses a seperator to prevent sufix and prefix exploration in the cache and checks // if the retrieved records match the requested subject and or session id as well, to prevent false positives. -func GetLogoutRecords(suse SuSe, mode LogoutMode, store microstore.Store) ([]*microstore.Record, error) { +func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, error) { + // get subject.session mode + mode := getLogoutMode(suse) + var key string var opts []microstore.ReadOption switch mode { - case LogoutModeSubject: + case logoutModeSubject: // the dot at the end prevents prefix exploration in the cache, // so only keys that start with 'subject.*' will be returned, but not 'sub*'. key = suse.encodedSubject + "." opts = append(opts, microstore.ReadPrefix()) - case LogoutModeSession: + case logoutModeSession: // the dot at the beginning prevents sufix exploration in the cache, // so only keys that end with '*.session' will be returned, but not '*sion'. key = "." + suse.encodedSession @@ -156,7 +159,7 @@ func GetLogoutRecords(suse SuSe, mode LogoutMode, store microstore.Store) ([]*mi return nil, microstore.ErrNotFound } - if mode == LogoutModeSession && len(records) > 1 { + if mode == logoutModeSession && len(records) > 1 { return nil, errors.Join(errors.New("multiple session records found"), ErrSuspiciousCacheResult) } @@ -171,10 +174,10 @@ func GetLogoutRecords(suse SuSe, mode LogoutMode, store microstore.Store) ([]*mi switch { // in subject mode, the subject must match, but the session id can be different - case mode == LogoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject: + case mode == logoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject: continue // in session mode, the session id must match, but the subject can be different - case mode == LogoutModeSession && suse.encodedSession == recordSuSe.encodedSession: + case mode == logoutModeSession && suse.encodedSession == recordSuSe.encodedSession: continue } diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go index a653be2247..617bd6d9e0 100644 --- a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go @@ -137,33 +137,33 @@ func TestGetLogoutMode(t *testing.T) { tests := []struct { name string suSe SuSe - want LogoutMode + want logoutMode }{ { name: "key variation: '.session'", suSe: mustNewSuSe(t, "", "session"), - want: LogoutModeSession, + want: logoutModeSession, }, { name: "key variation: 'subject.session'", suSe: mustNewSuSe(t, "subject", "session"), - want: LogoutModeSession, + want: logoutModeSession, }, { name: "key variation: 'subject.'", suSe: mustNewSuSe(t, "subject", ""), - want: LogoutModeSubject, + want: logoutModeSubject, }, { name: "key variation: 'empty'", suSe: SuSe{}, - want: LogoutModeUndefined, + want: logoutModeUndefined, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mode := GetLogoutMode(tt.suSe) + mode := getLogoutMode(tt.suSe) require.Equal(t, tt.want, mode) }) } @@ -197,34 +197,13 @@ func TestGetLogoutRecords(t *testing.T) { tests := []struct { name string suSe SuSe - mode LogoutMode store func(t *testing.T) store.Store wantRecords []*store.Record wantErrs []error }{ - { - name: "fails if mode is unknown", - suSe: mustNewSuSe(t, "", "session-a"), - mode: LogoutModeUndefined, - store: func(t *testing.T) store.Store { - return sessionStore - }, - wantRecords: []*store.Record{}, - wantErrs: []error{ErrSuspiciousCacheResult}, - }, - { - name: "fails if mode is any random int", - suSe: mustNewSuSe(t, "", "session-a"), - mode: 999, - store: func(t *testing.T) store.Store { - return sessionStore - }, - wantRecords: []*store.Record{}, - wantErrs: []error{ErrSuspiciousCacheResult}}, { name: "fails if multiple session records are found", suSe: mustNewSuSe(t, "", "session-a"), - mode: LogoutModeSession, store: func(t *testing.T) store.Store { s := mocks.NewStore(t) s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ @@ -238,7 +217,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "fails if the record key is not ok", suSe: mustNewSuSe(t, "", "session-a"), - mode: LogoutModeSession, store: func(t *testing.T) store.Store { s := mocks.NewStore(t) s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ @@ -252,7 +230,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "fails if the session does not match the retrieved record", suSe: mustNewSuSe(t, "", "session-a"), - mode: LogoutModeSession, store: func(t *testing.T) store.Store { s := mocks.NewStore(t) s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ @@ -265,7 +242,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "fails if the subject does not match the retrieved record", suSe: mustNewSuSe(t, "subject-a", ""), - mode: LogoutModeSubject, store: func(t *testing.T) store.Store { s := mocks.NewStore(t) s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{ @@ -279,7 +255,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "key variation: 'session-a'", suSe: mustNewSuSe(t, "", "session-a"), - mode: LogoutModeSession, store: func(*testing.T) store.Store { return sessionStore }, @@ -288,7 +263,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "key variation: 'session-b'", suSe: mustNewSuSe(t, "", "session-b"), - mode: LogoutModeSession, store: func(*testing.T) store.Store { return sessionStore }, @@ -297,7 +271,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "key variation: 'session-c'", suSe: mustNewSuSe(t, "", "session-c"), - mode: LogoutModeSession, store: func(*testing.T) store.Store { return sessionStore }, @@ -306,7 +279,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "key variation: 'ession-c'", suSe: mustNewSuSe(t, "", "ession-c"), - mode: LogoutModeSession, store: func(*testing.T) store.Store { return sessionStore }, @@ -316,7 +288,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "key variation: 'subject-a'", suSe: mustNewSuSe(t, "subject-a", ""), - mode: LogoutModeSubject, store: func(*testing.T) store.Store { return sessionStore }, @@ -325,7 +296,6 @@ func TestGetLogoutRecords(t *testing.T) { { name: "key variation: 'subject-'", suSe: mustNewSuSe(t, "subject-", ""), - mode: LogoutModeSubject, store: func(*testing.T) store.Store { return sessionStore }, @@ -336,7 +306,7 @@ func TestGetLogoutRecords(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - records, err := GetLogoutRecords(tt.suSe, tt.mode, tt.store(t)) + records, err := GetLogoutRecords(tt.suSe, tt.store(t)) for _, wantErr := range tt.wantErrs { require.ErrorIs(t, err, wantErr) }