From 00afda3c04862ef390ee8a1da26d4e78dcb8d15a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Friedrich=20Dreyer?= Date: Thu, 18 Jun 2020 21:40:06 +0200 Subject: [PATCH] test middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jörn Friedrich Dreyer --- pkg/middleware/account_uuid.go | 8 ++++---- pkg/middleware/account_uuid_test.go | 26 ++++++++++++++++++++++++-- pkg/middleware/openidconnect.go | 5 +---- 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/pkg/middleware/account_uuid.go b/pkg/middleware/account_uuid.go index 88ebbbb9a1..c4d53255f8 100644 --- a/pkg/middleware/account_uuid.go +++ b/pkg/middleware/account_uuid.go @@ -10,7 +10,7 @@ import ( "github.com/cs3org/reva/pkg/token/manager/jwt" acc "github.com/owncloud/ocis-accounts/pkg/proto/v0" "github.com/owncloud/ocis-pkg/v2/log" - ocisoidc "github.com/owncloud/ocis-pkg/v2/oidc" + oidc "github.com/owncloud/ocis-pkg/v2/oidc" "github.com/owncloud/ocis-proxy/pkg/config" ) @@ -56,7 +56,7 @@ func newAccountUUIDOptions(opts ...AccountMiddlewareOption) AccountMiddlewareOpt return opt } -func getAccount(l log.Logger, claims ocisoidc.StandardClaims, ac acc.AccountsService) (account *acc.Account, status int) { +func getAccount(l log.Logger, claims *oidc.StandardClaims, ac acc.AccountsService) (account *acc.Account, status int) { entry, err := svcCache.Get(AccountsKey, claims.Email) if err != nil { l.Debug().Msgf("No cache entry for %v", claims.Email) @@ -121,8 +121,8 @@ func AccountUUID(opts ...AccountMiddlewareOption) func(next http.Handler) http.H return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := opt.Logger - claims, ok := r.Context().Value(ClaimsKey).(ocisoidc.StandardClaims) - if !ok { + claims := oidc.FromContext(r.Context()) + if claims == nil { next.ServeHTTP(w, r) return } diff --git a/pkg/middleware/account_uuid_test.go b/pkg/middleware/account_uuid_test.go index 118599922e..e86d9abc64 100644 --- a/pkg/middleware/account_uuid_test.go +++ b/pkg/middleware/account_uuid_test.go @@ -4,28 +4,50 @@ import ( "context" "fmt" "net/http" + "net/http/httptest" "testing" "github.com/micro/go-micro/v2/client" "github.com/owncloud/ocis-accounts/pkg/proto/v0" "github.com/owncloud/ocis-pkg/v2/log" "github.com/owncloud/ocis-pkg/v2/oidc" + "github.com/owncloud/ocis-proxy/pkg/config" ) // TODO testing the getAccount method should inject a cache func TestGetAccountSuccess(t *testing.T) { svcCache.Invalidate(AccountsKey, "success") - if _, status := getAccount(log.NewLogger(), oidc.StandardClaims{Email: "success"}, mockAccSvc(false)); status != 0 { + if _, status := getAccount(log.NewLogger(), &oidc.StandardClaims{Email: "success"}, mockAccSvc(false)); status != 0 { t.Errorf("expected an account") } } func TestGetAccountInternalError(t *testing.T) { svcCache.Invalidate(AccountsKey, "failure") - if _, status := getAccount(log.NewLogger(), oidc.StandardClaims{Email: "failure"}, mockAccSvc(true)); status != http.StatusInternalServerError { + if _, status := getAccount(log.NewLogger(), &oidc.StandardClaims{Email: "failure"}, mockAccSvc(true)); status != http.StatusInternalServerError { t.Errorf("expected an internal server error") } } +func TestAccountUUIDHandler(t *testing.T) { + svcCache.Invalidate(AccountsKey, "success") + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + m := AccountUUID( + Logger(log.NewLogger()), + TokenManagerConfig(config.TokenManager{JWTSecret: "secret"}), + AccountsClient(mockAccSvc(false)), + )(next) + + r := httptest.NewRequest(http.MethodGet, "http://www.example.com", nil) + w := httptest.NewRecorder() + ctx := oidc.NewContext(r.Context(), &oidc.StandardClaims{Email: "success"}) + r = r.WithContext(ctx) + m.ServeHTTP(w, r) + + if r.Header.Get("x-access-token") == "" { + t.Errorf("expected a token") + } +} + func mockAccSvc(retErr bool) proto.AccountsService { if retErr { return &proto.MockAccountsService{ diff --git a/pkg/middleware/openidconnect.go b/pkg/middleware/openidconnect.go index c62c757310..5280f59a99 100644 --- a/pkg/middleware/openidconnect.go +++ b/pkg/middleware/openidconnect.go @@ -22,9 +22,6 @@ var ( svcCache = cache.NewCache( cache.Size(256), ) - - // ClaimsKey works as a context key for user claims - ClaimsKey interface{} = "claims" ) // newOIDCOptions initializes the available default options. @@ -113,7 +110,7 @@ func OpenIDConnect(opts ...ocisoidc.Option) func(next http.Handler) http.Handler } // inject claims to the request context for the account_uuid middleware. - ctxWithClaims := context.WithValue(r.Context(), ClaimsKey, claims) + ctxWithClaims := ocisoidc.NewContext(r.Context(), &claims) r = r.WithContext(ctxWithClaims) opt.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")