refactor: make the logout mode private

This commit is contained in:
Florian Schade
2026-02-20 16:52:29 +01:00
committed by Christian Richter
parent fd614eacf1
commit e8ecbd7af1
4 changed files with 33 additions and 64 deletions

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"github.com/vmihailenco/msgpack/v5"
"go-micro.dev/v4/store"
@@ -117,13 +116,13 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
m.Logger.Error().Err(err).Msg("failed to write to userinfo cache")
}
// if the claim has no subject, we can leave it empty,
// it's important to keep the dot in the key to prevent
// sufix and prefix exploration in the cache.
// fail if creating the storage key fails,
// it means there is no subject and no session.
//
// ok: {key: ".sessionId"}
// ok: {key: "subject."}
// ok: {key: "subject.sessionId"}
// fail: {key: "."}
subjectSessionKey, err := staticroutes.NewRecordKey(aClaims.Subject, aClaims.SessionID)
if err != nil {
m.Logger.Error().Err(err).Msg("failed to build subject.session")

View File

@@ -75,10 +75,7 @@ func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re
return
}
// find out which mode of backchannel logout we are in
// by checking if the session or subject is present in the token
logoutMode := bcl.GetLogoutMode(requestSubjectAndSession)
lookupRecords, err := bcl.GetLogoutRecords(requestSubjectAndSession, logoutMode, s.UserInfoCache)
lookupRecords, err := bcl.GetLogoutRecords(requestSubjectAndSession, s.UserInfoCache)
if errors.Is(err, microstore.ErrNotFound) || len(lookupRecords) == 0 {
render.Status(r, http.StatusOK)
render.JSON(w, r, nil)

View File

@@ -97,27 +97,27 @@ func NewSuSe(key string) (SuSe, error) {
return suse, nil
}
// LogoutMode defines the mode of backchannel logout, either by session or by subject
type LogoutMode int
// logoutMode defines the mode of backchannel logout, either by session or by subject
type logoutMode int
const (
// LogoutModeUndefined is used when the logout mode cannot be determined
LogoutModeUndefined LogoutMode = iota
// LogoutModeSubject is used when the logout mode is determined by the subject
LogoutModeSubject
// LogoutModeSession is used when the logout mode is determined by the session id
LogoutModeSession
// logoutModeUndefined is used when the logout mode cannot be determined
logoutModeUndefined logoutMode = iota
// logoutModeSubject is used when the logout mode is determined by the subject
logoutModeSubject
// logoutModeSession is used when the logout mode is determined by the session id
logoutModeSession
)
// GetLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct
func GetLogoutMode(suse SuSe) LogoutMode {
// getLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct
func getLogoutMode(suse SuSe) logoutMode {
switch {
case suse.encodedSession == "" && suse.encodedSubject != "":
return LogoutModeSubject
return logoutModeSubject
case suse.encodedSession != "":
return LogoutModeSession
return logoutModeSession
default:
return LogoutModeUndefined
return logoutModeUndefined
}
}
@@ -128,16 +128,19 @@ var ErrSuspiciousCacheResult = errors.New("suspicious cache result")
// logout mode and the provided SuSe struct.
// it uses a seperator to prevent sufix and prefix exploration in the cache and checks
// if the retrieved records match the requested subject and or session id as well, to prevent false positives.
func GetLogoutRecords(suse SuSe, mode LogoutMode, store microstore.Store) ([]*microstore.Record, error) {
func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, error) {
// get subject.session mode
mode := getLogoutMode(suse)
var key string
var opts []microstore.ReadOption
switch mode {
case LogoutModeSubject:
case logoutModeSubject:
// the dot at the end prevents prefix exploration in the cache,
// so only keys that start with 'subject.*' will be returned, but not 'sub*'.
key = suse.encodedSubject + "."
opts = append(opts, microstore.ReadPrefix())
case LogoutModeSession:
case logoutModeSession:
// the dot at the beginning prevents sufix exploration in the cache,
// so only keys that end with '*.session' will be returned, but not '*sion'.
key = "." + suse.encodedSession
@@ -156,7 +159,7 @@ func GetLogoutRecords(suse SuSe, mode LogoutMode, store microstore.Store) ([]*mi
return nil, microstore.ErrNotFound
}
if mode == LogoutModeSession && len(records) > 1 {
if mode == logoutModeSession && len(records) > 1 {
return nil, errors.Join(errors.New("multiple session records found"), ErrSuspiciousCacheResult)
}
@@ -171,10 +174,10 @@ func GetLogoutRecords(suse SuSe, mode LogoutMode, store microstore.Store) ([]*mi
switch {
// in subject mode, the subject must match, but the session id can be different
case mode == LogoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject:
case mode == logoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject:
continue
// in session mode, the session id must match, but the subject can be different
case mode == LogoutModeSession && suse.encodedSession == recordSuSe.encodedSession:
case mode == logoutModeSession && suse.encodedSession == recordSuSe.encodedSession:
continue
}

View File

@@ -137,33 +137,33 @@ func TestGetLogoutMode(t *testing.T) {
tests := []struct {
name string
suSe SuSe
want LogoutMode
want logoutMode
}{
{
name: "key variation: '.session'",
suSe: mustNewSuSe(t, "", "session"),
want: LogoutModeSession,
want: logoutModeSession,
},
{
name: "key variation: 'subject.session'",
suSe: mustNewSuSe(t, "subject", "session"),
want: LogoutModeSession,
want: logoutModeSession,
},
{
name: "key variation: 'subject.'",
suSe: mustNewSuSe(t, "subject", ""),
want: LogoutModeSubject,
want: logoutModeSubject,
},
{
name: "key variation: 'empty'",
suSe: SuSe{},
want: LogoutModeUndefined,
want: logoutModeUndefined,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mode := GetLogoutMode(tt.suSe)
mode := getLogoutMode(tt.suSe)
require.Equal(t, tt.want, mode)
})
}
@@ -197,34 +197,13 @@ func TestGetLogoutRecords(t *testing.T) {
tests := []struct {
name string
suSe SuSe
mode LogoutMode
store func(t *testing.T) store.Store
wantRecords []*store.Record
wantErrs []error
}{
{
name: "fails if mode is unknown",
suSe: mustNewSuSe(t, "", "session-a"),
mode: LogoutModeUndefined,
store: func(t *testing.T) store.Store {
return sessionStore
},
wantRecords: []*store.Record{},
wantErrs: []error{ErrSuspiciousCacheResult},
},
{
name: "fails if mode is any random int",
suSe: mustNewSuSe(t, "", "session-a"),
mode: 999,
store: func(t *testing.T) store.Store {
return sessionStore
},
wantRecords: []*store.Record{},
wantErrs: []error{ErrSuspiciousCacheResult}},
{
name: "fails if multiple session records are found",
suSe: mustNewSuSe(t, "", "session-a"),
mode: LogoutModeSession,
store: func(t *testing.T) store.Store {
s := mocks.NewStore(t)
s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{
@@ -238,7 +217,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "fails if the record key is not ok",
suSe: mustNewSuSe(t, "", "session-a"),
mode: LogoutModeSession,
store: func(t *testing.T) store.Store {
s := mocks.NewStore(t)
s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{
@@ -252,7 +230,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "fails if the session does not match the retrieved record",
suSe: mustNewSuSe(t, "", "session-a"),
mode: LogoutModeSession,
store: func(t *testing.T) store.Store {
s := mocks.NewStore(t)
s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{
@@ -265,7 +242,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "fails if the subject does not match the retrieved record",
suSe: mustNewSuSe(t, "subject-a", ""),
mode: LogoutModeSubject,
store: func(t *testing.T) store.Store {
s := mocks.NewStore(t)
s.EXPECT().Read(mock.Anything, mock.Anything).Return([]*store.Record{
@@ -279,7 +255,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "key variation: 'session-a'",
suSe: mustNewSuSe(t, "", "session-a"),
mode: LogoutModeSession,
store: func(*testing.T) store.Store {
return sessionStore
},
@@ -288,7 +263,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "key variation: 'session-b'",
suSe: mustNewSuSe(t, "", "session-b"),
mode: LogoutModeSession,
store: func(*testing.T) store.Store {
return sessionStore
},
@@ -297,7 +271,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "key variation: 'session-c'",
suSe: mustNewSuSe(t, "", "session-c"),
mode: LogoutModeSession,
store: func(*testing.T) store.Store {
return sessionStore
},
@@ -306,7 +279,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "key variation: 'ession-c'",
suSe: mustNewSuSe(t, "", "ession-c"),
mode: LogoutModeSession,
store: func(*testing.T) store.Store {
return sessionStore
},
@@ -316,7 +288,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "key variation: 'subject-a'",
suSe: mustNewSuSe(t, "subject-a", ""),
mode: LogoutModeSubject,
store: func(*testing.T) store.Store {
return sessionStore
},
@@ -325,7 +296,6 @@ func TestGetLogoutRecords(t *testing.T) {
{
name: "key variation: 'subject-'",
suSe: mustNewSuSe(t, "subject-", ""),
mode: LogoutModeSubject,
store: func(*testing.T) store.Store {
return sessionStore
},
@@ -336,7 +306,7 @@ func TestGetLogoutRecords(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
records, err := GetLogoutRecords(tt.suSe, tt.mode, tt.store(t))
records, err := GetLogoutRecords(tt.suSe, tt.store(t))
for _, wantErr := range tt.wantErrs {
require.ErrorIs(t, err, wantErr)
}