diff --git a/proxy/pkg/middleware/oidc_auth.go b/proxy/pkg/middleware/oidc_auth.go index de44511e7d..f7aaaacdc8 100644 --- a/proxy/pkg/middleware/oidc_auth.go +++ b/proxy/pkg/middleware/oidc_auth.go @@ -26,21 +26,46 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler { options := newOptions(optionSetters...) tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize)) + h := oidcAuth{ + logger: options.Logger, + providerFunc: options.OIDCProviderFunc, + httpClient: options.HTTPClient, + oidcIss: options.OIDCIss, + tokenCache: &tokenCache, + tokenCacheTTL: options.UserinfoCacheTTL, + } + return func(next http.Handler) http.Handler { - return &oidcAuth{ - next: next, - logger: options.Logger, - providerFunc: options.OIDCProviderFunc, - httpClient: options.HTTPClient, - oidcIss: options.OIDCIss, - tokenCache: &tokenCache, - tokenCacheTTL: options.UserinfoCacheTTL, - } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if !h.shouldServe(req) { + next.ServeHTTP(w, req) + return + } + + if h.getProvider() == nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") + + claims, status := h.getClaims(token, req) + if status != 0 { + w.WriteHeader(status) + return + } + + // inject claims to the request context for the account_uuid middleware. + req = req.WithContext(oidc.NewContext(req.Context(), &claims)) + + // store claims in context + // uses the original context, not the one with probably reduced security + next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims))) + }) } } type oidcAuth struct { - next http.Handler logger log.Logger provider OIDCProvider providerFunc func() (OIDCProvider, error) @@ -50,34 +75,6 @@ type oidcAuth struct { tokenCacheTTL time.Duration } -func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) { - - if !m.shouldServe(req) { - m.next.ServeHTTP(w, req) - return - } - - if m.getProvider() == nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") - - claims, status := m.getClaims(token, req) - if status != 0 { - w.WriteHeader(status) - return - } - - // inject claims to the request context for the account_uuid middleware. - req = req.WithContext(oidc.NewContext(req.Context(), &claims)) - - // store claims in context - // uses the original context, not the one with probably reduced security - m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims))) -} - func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.StandardClaims, status int) { hit := m.tokenCache.Get(token) if hit == nil { @@ -164,7 +161,7 @@ func (m oidcAuth) shouldServe(req *http.Request) bool { return strings.HasPrefix(header, "Bearer ") } -func (m oidcAuth) getProvider() OIDCProvider { +func (m *oidcAuth) getProvider() OIDCProvider { if m.provider == nil { // Lazily initialize a provider