From cb8934081f3efe8410432f63b0b878049a437627 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Thu, 29 Aug 2024 10:17:07 +0200 Subject: [PATCH] proxy(oidc): Emit a UserSignedIn event on new session Every time the OIDC middleware sees a new access token (i.e when it needs to update the userinfo cache) we consider that as a new login. In this case the middleware add a new flag to the context, which is then used by the accountresolver middleware to publish a UserSignedIn event. The event needs to be sent by the accountresolver middleware, because only at that point we know the user id of the user that just logged in. (It would probably makes sense to merge the auth and account middleware into a single component to avoid passing flags around via context) --- ocis-pkg/oidc/context.go | 14 ++++++++++ services/proxy/pkg/command/server.go | 9 ++++-- .../proxy/pkg/middleware/account_resolver.go | 15 ++++++++++ services/proxy/pkg/middleware/oidc_auth.go | 28 +++++++++++++------ services/proxy/pkg/middleware/options.go | 11 +++++++- 5 files changed, 64 insertions(+), 13 deletions(-) diff --git a/ocis-pkg/oidc/context.go b/ocis-pkg/oidc/context.go index 1f4bf2d650..fbb55cdf1f 100644 --- a/ocis-pkg/oidc/context.go +++ b/ocis-pkg/oidc/context.go @@ -5,6 +5,9 @@ import "context" // contextKey is the key for oidc claims in a context type contextKey struct{} +// newSessionFlagKey is the key for the new session flag in a context +type newSessionFlagKey struct{} + // NewContext makes a new context that contains the OpenID connect claims in a map. func NewContext(parent context.Context, c map[string]interface{}) context.Context { return context.WithValue(parent, contextKey{}, c) @@ -15,3 +18,14 @@ func FromContext(ctx context.Context) map[string]interface{} { s, _ := ctx.Value(contextKey{}).(map[string]interface{}) return s } + +// NewContextSessionFlag makes a new context that contains the new session flag. +func NewContextSessionFlag(ctx context.Context, flag bool) context.Context { + return context.WithValue(ctx, newSessionFlagKey{}, flag) +} + +// NewSessionFlagFromContext returns the new session flag stored in a context. +func NewSessionFlagFromContext(ctx context.Context) bool { + s, _ := ctx.Value(newSessionFlagKey{}).(bool) + return s +} diff --git a/services/proxy/pkg/command/server.go b/services/proxy/pkg/command/server.go index 4c2fb9a237..ae17b743a1 100644 --- a/services/proxy/pkg/command/server.go +++ b/services/proxy/pkg/command/server.go @@ -182,7 +182,7 @@ func Server(cfg *config.Config) *cli.Command { } { - middlewares := loadMiddlewares(logger, cfg, userInfoCache, signingKeyStore, traceProvider, *m, userProvider, gatewaySelector, serviceSelector) + middlewares := loadMiddlewares(logger, cfg, userInfoCache, signingKeyStore, traceProvider, *m, userProvider, publisher, gatewaySelector, serviceSelector) server, err := proxyHTTP.Server( proxyHTTP.Handler(lh.Handler()), @@ -236,8 +236,10 @@ func Server(cfg *config.Config) *cli.Command { } func loadMiddlewares(logger log.Logger, cfg *config.Config, - userInfoCache, signingKeyStore microstore.Store, traceProvider trace.TracerProvider, metrics metrics.Metrics, - userProvider backend.UserBackend, gatewaySelector pool.Selectable[gateway.GatewayAPIClient], serviceSelector selector.Selector) alice.Chain { + userInfoCache, signingKeyStore microstore.Store, + traceProvider trace.TracerProvider, metrics metrics.Metrics, + userProvider backend.UserBackend, publisher events.Publisher, + gatewaySelector pool.Selectable[gateway.GatewayAPIClient], serviceSelector selector.Selector) alice.Chain { rolesClient := settingssvc.NewRoleService("com.owncloud.api.settings", cfg.GrpcClient) policiesProviderClient := policiessvc.NewPoliciesProviderService("com.owncloud.api.policies", cfg.GrpcClient) @@ -354,6 +356,7 @@ func loadMiddlewares(logger log.Logger, cfg *config.Config, middleware.UserOIDCClaim(cfg.UserOIDCClaim), middleware.UserCS3Claim(cfg.UserCS3Claim), middleware.AutoprovisionAccounts(cfg.AutoprovisionAccounts), + middleware.EventsPublisher(publisher), ), middleware.SelectorCookie( middleware.Logger(logger), diff --git a/services/proxy/pkg/middleware/account_resolver.go b/services/proxy/pkg/middleware/account_resolver.go index f560af347e..422d067ed8 100644 --- a/services/proxy/pkg/middleware/account_resolver.go +++ b/services/proxy/pkg/middleware/account_resolver.go @@ -11,6 +11,8 @@ import ( "github.com/owncloud/ocis/v2/services/proxy/pkg/userroles" revactx "github.com/cs3org/reva/v2/pkg/ctx" + "github.com/cs3org/reva/v2/pkg/events" + "github.com/cs3org/reva/v2/pkg/utils" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" ) @@ -37,6 +39,7 @@ func AccountResolver(optionSetters ...Option) func(next http.Handler) http.Handl userRoleAssigner: options.UserRoleAssigner, autoProvisionAccounts: options.AutoprovisionAccounts, lastGroupSyncCache: lastGroupSyncCache, + eventsPublisher: options.EventsPublisher, } } } @@ -53,6 +56,7 @@ type accountResolver struct { // memberships was done for a specific user. This is used to trigger a sync // with every single request. lastGroupSyncCache *ttlcache.Cache[string, struct{}] + eventsPublisher events.Publisher } func readUserIDClaim(path string, claims map[string]interface{}) (string, error) { @@ -172,6 +176,17 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } + // If this is a new session, publish user login event + if newSession := oidc.NewSessionFlagFromContext(ctx); newSession && m.eventsPublisher != nil { + event := events.UserSignedIn{ + Executant: user.Id, + Timestamp: utils.TimeToTS(time.Now()), + } + if err := events.Publish(req.Context(), m.eventsPublisher, event); err != nil { + m.logger.Error().Err(err).Msg("could not publish user signin event.") + } + } + // add user to context for selectors ctx = revactx.ContextSetUser(ctx, user) req = req.WithContext(ctx) diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index 47b9c424d8..ed29111c02 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -53,7 +53,7 @@ type OIDCAuthenticator struct { TimeFunc func() time.Time } -func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) { +func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, bool, error) { var claims map[string]interface{} // use a 64 bytes long hash to have 256-bit collision resistance. @@ -69,16 +69,16 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil { m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo") if ok := verifyExpiresAt(claims, m.TimeFunc()); !ok { - return nil, jwt.ErrTokenExpired + return nil, false, jwt.ErrTokenExpired } - return claims, nil + return claims, false, nil } m.Logger.Error().Err(err).Msg("could not unmarshal userinfo") } aClaims, claims, err := m.oidcClient.VerifyAccessToken(req.Context(), token) if err != nil { - return nil, errors.Wrap(err, "failed to verify access token") + return nil, false, errors.Wrap(err, "failed to verify access token") } if !m.skipUserInfo { @@ -91,10 +91,10 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri oauth2.StaticTokenSource(oauth2Token), ) if err != nil { - return nil, errors.Wrap(err, "failed to get userinfo") + return nil, false, errors.Wrap(err, "failed to get userinfo") } if err := userInfo.Claims(&claims); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal userinfo claims") + return nil, false, errors.Wrap(err, "failed to unmarshal userinfo claims") } } @@ -128,8 +128,12 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri } }() + // If we get here this was a new login (or a renewal of the token) + // add a flag about that to the claims, to be able to distinguish + // it in the accountresolver middleware + m.Logger.Debug().Interface("claims", claims).Msg("extracted claims") - return claims, nil + return claims, true, nil } // extractExpiration tries to extract the expriration time from the access token @@ -180,7 +184,7 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) return nil, false } - claims, err := m.getClaims(token, r) + claims, newSession, err := m.getClaims(token, r) if err != nil { host, port, _ := net.SplitHostPort(r.RemoteAddr) m.Logger.Error(). @@ -198,5 +202,11 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) Str("authenticator", "oidc"). Str("path", r.URL.Path). Msg("successfully authenticated request") - return r.WithContext(oidc.NewContext(r.Context(), claims)), true + + ctx := r.Context() + if newSession { + ctx = oidc.NewContextSessionFlag(ctx, true) + } + + return r.WithContext(oidc.NewContext(ctx, claims)), true } diff --git a/services/proxy/pkg/middleware/options.go b/services/proxy/pkg/middleware/options.go index 917f5ca311..90b60c45b4 100644 --- a/services/proxy/pkg/middleware/options.go +++ b/services/proxy/pkg/middleware/options.go @@ -5,6 +5,7 @@ import ( "time" gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1" + "github.com/cs3org/reva/v2/pkg/events" "github.com/cs3org/reva/v2/pkg/rgrpc/todo/pool" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" @@ -69,7 +70,8 @@ type Options struct { // TraceProvider sets the tracing provider. TraceProvider trace.TracerProvider // SkipUserInfo prevents the oidc middleware from querying the userinfo endpoint and read any claims directly from the access token instead - SkipUserInfo bool + SkipUserInfo bool + EventsPublisher events.Publisher } // newOptions initializes the available default options. @@ -236,3 +238,10 @@ func SkipUserInfo(val bool) Option { o.SkipUserInfo = val } } + +// EventsPublisher sets the events publisher. +func EventsPublisher(ep events.Publisher) Option { + return func(o *Options) { + o.EventsPublisher = ep + } +}