From 88e8bb1b7218e90d81ba449e4c9466ccbe9413d5 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Tue, 7 Mar 2023 17:30:07 +0100 Subject: [PATCH] account_resolver: Handle user roles separately from user lookup This removes the "withRoles" flag from the GetUserByClaims lookup and move the functionality into a separate method. This should make the code a bit more readable in preparation for maintaining the RoleAssignments from OIDC claims. --- services/proxy/pkg/command/server.go | 1 - .../proxy/pkg/middleware/account_resolver.go | 14 +++++-- .../pkg/middleware/account_resolver_test.go | 1 + .../proxy/pkg/middleware/signed_url_auth.go | 12 +++++- services/proxy/pkg/user/backend/backend.go | 4 +- services/proxy/pkg/user/backend/cs3.go | 21 ++++------ .../pkg/user/backend/mocks/UserBackend.go | 42 +++++++++++++------ 7 files changed, 64 insertions(+), 31 deletions(-) diff --git a/services/proxy/pkg/command/server.go b/services/proxy/pkg/command/server.go index 05a5b718ed..aaddcf997a 100644 --- a/services/proxy/pkg/command/server.go +++ b/services/proxy/pkg/command/server.go @@ -236,7 +236,6 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config) ), middleware.SelectorCookie( middleware.Logger(logger), - middleware.UserProvider(userProvider), middleware.PolicySelectorConfig(*cfg.PolicySelector), ), middleware.Policies(logger, cfg.PoliciesMiddleware.Query), diff --git a/services/proxy/pkg/middleware/account_resolver.go b/services/proxy/pkg/middleware/account_resolver.go index 6ccdf987d1..581b46a120 100644 --- a/services/proxy/pkg/middleware/account_resolver.go +++ b/services/proxy/pkg/middleware/account_resolver.go @@ -62,7 +62,7 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - user, token, err = m.userProvider.GetUserByClaims(req.Context(), m.userCS3Claim, value, true) + user, token, err = m.userProvider.GetUserByClaims(req.Context(), m.userCS3Claim, value) if errors.Is(err, backend.ErrAccountNotFound) { m.logger.Debug().Str("claim", m.userOIDCClaim).Str("value", value).Msg("User by claim not found") @@ -76,7 +76,7 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) { if err != nil { m.logger.Error().Err(err).Msg("Autoprovisioning user failed") } - user, token, err = m.userProvider.GetUserByClaims(req.Context(), "userid", user.Id.OpaqueId, true) + user, token, err = m.userProvider.GetUserByClaims(req.Context(), "userid", user.Id.OpaqueId) if err != nil { m.logger.Error().Err(err).Str("userid", user.Id.OpaqueId).Msg("Error getting token for autoprovisioned user") } @@ -94,6 +94,14 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } + // resolve the user's roles + user, err = m.userProvider.GetUserRoles(ctx, user) + if err != nil { + m.logger.Error().Err(err).Msg("Could not get user roles") + w.WriteHeader(http.StatusInternalServerError) + return + } + // add user to context for selectors ctx = revactx.ContextSetUser(ctx, user) req = req.WithContext(ctx) @@ -101,7 +109,7 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) { m.logger.Debug().Interface("claims", claims).Interface("user", user).Msg("associated claims with user") } else if user != nil { var err error - _, token, err = m.userProvider.GetUserByClaims(req.Context(), "username", user.Username, true) + _, token, err = m.userProvider.GetUserByClaims(req.Context(), "username", user.Username) if errors.Is(err, backend.ErrAccountDisabled) { m.logger.Debug().Interface("user", user).Msg("Disabled") diff --git a/services/proxy/pkg/middleware/account_resolver_test.go b/services/proxy/pkg/middleware/account_resolver_test.go index 5b4f5dbdc9..2d1ae3dedc 100644 --- a/services/proxy/pkg/middleware/account_resolver_test.go +++ b/services/proxy/pkg/middleware/account_resolver_test.go @@ -122,6 +122,7 @@ func newMockAccountResolver(userBackendResult *userv1beta1.User, userBackendErr ub := mocks.UserBackend{} ub.On("GetUserByClaims", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(userBackendResult, token, userBackendErr) + ub.On("GetUserRoles", mock.Anything, mock.Anything).Return(userBackendResult, nil) return AccountResolver( Logger(log.NewLogger()), diff --git a/services/proxy/pkg/middleware/signed_url_auth.go b/services/proxy/pkg/middleware/signed_url_auth.go index 793f6d8adc..787fd6908a 100644 --- a/services/proxy/pkg/middleware/signed_url_auth.go +++ b/services/proxy/pkg/middleware/signed_url_auth.go @@ -192,7 +192,17 @@ func (m SignedURLAuthenticator) Authenticate(r *http.Request) (*http.Request, bo return nil, false } - user, _, err := m.UserProvider.GetUserByClaims(r.Context(), "username", r.URL.Query().Get(_paramOCCredential), true) + user, _, err := m.UserProvider.GetUserByClaims(r.Context(), "username", r.URL.Query().Get(_paramOCCredential)) + if err != nil { + m.Logger.Error(). + Err(err). + Str("authenticator", "signed_url"). + Str("path", r.URL.Path). + Msg("Could not get user by claim") + return nil, false + } + + user, err = m.UserProvider.GetUserRoles(r.Context(), user) if err != nil { m.Logger.Error(). Err(err). diff --git a/services/proxy/pkg/user/backend/backend.go b/services/proxy/pkg/user/backend/backend.go index 909176fdd2..6cc911c49b 100644 --- a/services/proxy/pkg/user/backend/backend.go +++ b/services/proxy/pkg/user/backend/backend.go @@ -25,10 +25,10 @@ var ( // UserBackend allows the proxy to retrieve users from different user-backends (accounts-service, CS3) type UserBackend interface { - GetUserByClaims(ctx context.Context, claim, value string, withRoles bool) (*cs3.User, string, error) + GetUserByClaims(ctx context.Context, claim, value string) (*cs3.User, string, error) + GetUserRoles(ctx context.Context, user *cs3.User) (*cs3.User, error) Authenticate(ctx context.Context, username string, password string) (*cs3.User, string, error) CreateUserFromClaims(ctx context.Context, claims map[string]interface{}) (*cs3.User, error) - GetUserGroups(ctx context.Context, userID string) } // RevaAuthenticator helper interface to mock auth-method from reva gateway-client. diff --git a/services/proxy/pkg/user/backend/cs3.go b/services/proxy/pkg/user/backend/cs3.go index b37cd48133..bf54bbe0e8 100644 --- a/services/proxy/pkg/user/backend/cs3.go +++ b/services/proxy/pkg/user/backend/cs3.go @@ -51,7 +51,7 @@ func NewCS3UserBackend(rs settingssvc.RoleService, ap RevaAuthenticator, machine } } -func (c *cs3backend) GetUserByClaims(ctx context.Context, claim, value string, withRoles bool) (*cs3.User, string, error) { +func (c *cs3backend) GetUserByClaims(ctx context.Context, claim, value string) (*cs3.User, string, error) { res, err := c.authProvider.Authenticate(ctx, &gateway.AuthenticateRequest{ Type: "machine", ClientId: claim + ":" + value, @@ -70,16 +70,17 @@ func (c *cs3backend) GetUserByClaims(ctx context.Context, claim, value string, w user := res.User - if !withRoles { - return user, res.Token, nil - } + return user, res.Token, nil +} +func (c *cs3backend) GetUserRoles(ctx context.Context, user *cs3.User) (*cs3.User, error) { var roleIDs []string if user.Id.Type != cs3.UserType_USER_TYPE_LIGHTWEIGHT { + var err error roleIDs, err = loadRolesIDs(ctx, user.Id.OpaqueId, c.settingsRoleService) if err != nil { c.logger.Error().Err(err).Msgf("Could not load roles") - return nil, "", err + return nil, err } if len(roleIDs) == 0 { @@ -95,7 +96,7 @@ func (c *cs3backend) GetUserByClaims(ctx context.Context, claim, value string, w }) if err != nil { c.logger.Error().Err(err).Msg("Could not add default role") - return nil, "", err + return nil, err } roleIDs = append(roleIDs, settingsService.BundleUUIDRoleUser) } @@ -105,6 +106,7 @@ func (c *cs3backend) GetUserByClaims(ctx context.Context, claim, value string, w enc, err := encodeRoleIDs(roleIDs) if err != nil { c.logger.Error().Err(err).Msg("Could not encode loaded roles") + return nil, err } if user.Opaque == nil { @@ -116,8 +118,7 @@ func (c *cs3backend) GetUserByClaims(ctx context.Context, claim, value string, w } else { user.Opaque.Map["roles"] = enc } - - return user, res.Token, nil + return user, nil } func (c *cs3backend) Authenticate(ctx context.Context, username string, password string) (*cs3.User, string, error) { @@ -198,10 +199,6 @@ func (c *cs3backend) CreateUserFromClaims(ctx context.Context, claims map[string return &cs3UserCreated, nil } -func (c cs3backend) GetUserGroups(ctx context.Context, userID string) { - panic("implement me") -} - func (c cs3backend) setupLibregraphClient(ctx context.Context, cs3token string) (*libregraph.APIClient, error) { // Use micro registry to resolve next graph service endpoint next, err := c.graphSelector.Select("com.owncloud.graph.graph") diff --git a/services/proxy/pkg/user/backend/mocks/UserBackend.go b/services/proxy/pkg/user/backend/mocks/UserBackend.go index 4669430622..c40cc5e493 100644 --- a/services/proxy/pkg/user/backend/mocks/UserBackend.go +++ b/services/proxy/pkg/user/backend/mocks/UserBackend.go @@ -67,13 +67,13 @@ func (_m *UserBackend) CreateUserFromClaims(ctx context.Context, claims map[stri return r0, r1 } -// GetUserByClaims provides a mock function with given fields: ctx, claim, value, withRoles -func (_m *UserBackend) GetUserByClaims(ctx context.Context, claim string, value string, withRoles bool) (*userv1beta1.User, string, error) { - ret := _m.Called(ctx, claim, value, withRoles) +// GetUserByClaims provides a mock function with given fields: ctx, claim, value +func (_m *UserBackend) GetUserByClaims(ctx context.Context, claim string, value string) (*userv1beta1.User, string, error) { + ret := _m.Called(ctx, claim, value) var r0 *userv1beta1.User - if rf, ok := ret.Get(0).(func(context.Context, string, string, bool) *userv1beta1.User); ok { - r0 = rf(ctx, claim, value, withRoles) + if rf, ok := ret.Get(0).(func(context.Context, string, string) *userv1beta1.User); ok { + r0 = rf(ctx, claim, value) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*userv1beta1.User) @@ -81,15 +81,15 @@ func (_m *UserBackend) GetUserByClaims(ctx context.Context, claim string, value } var r1 string - if rf, ok := ret.Get(1).(func(context.Context, string, string, bool) string); ok { - r1 = rf(ctx, claim, value, withRoles) + if rf, ok := ret.Get(1).(func(context.Context, string, string) string); ok { + r1 = rf(ctx, claim, value) } else { r1 = ret.Get(1).(string) } var r2 error - if rf, ok := ret.Get(2).(func(context.Context, string, string, bool) error); ok { - r2 = rf(ctx, claim, value, withRoles) + if rf, ok := ret.Get(2).(func(context.Context, string, string) error); ok { + r2 = rf(ctx, claim, value) } else { r2 = ret.Error(2) } @@ -97,9 +97,27 @@ func (_m *UserBackend) GetUserByClaims(ctx context.Context, claim string, value return r0, r1, r2 } -// GetUserGroups provides a mock function with given fields: ctx, userID -func (_m *UserBackend) GetUserGroups(ctx context.Context, userID string) { - _m.Called(ctx, userID) +// GetUserRoles provides a mock function with given fields: ctx, user +func (_m *UserBackend) GetUserRoles(ctx context.Context, user *userv1beta1.User) (*userv1beta1.User, error) { + ret := _m.Called(ctx, user) + + var r0 *userv1beta1.User + if rf, ok := ret.Get(0).(func(context.Context, *userv1beta1.User) *userv1beta1.User); ok { + r0 = rf(ctx, user) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*userv1beta1.User) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *userv1beta1.User) error); ok { + r1 = rf(ctx, user) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } type mockConstructorTestingTNewUserBackend interface {