diff --git a/ocis-pkg/oidc/client.go b/ocis-pkg/oidc/client.go index aefca8e02d..6073390703 100644 --- a/ocis-pkg/oidc/client.go +++ b/ocis-pkg/oidc/client.go @@ -10,15 +10,20 @@ import ( "net/http" "strings" "sync" + "time" + "github.com/MicahParks/keyfunc" gOidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v4" "github.com/owncloud/ocis/v2/ocis-pkg/log" + "github.com/owncloud/ocis/v2/services/proxy/pkg/config" "golang.org/x/oauth2" ) // OIDCProvider used to mock the oidc provider during tests type OIDCProvider interface { UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error) + VerifyAccessToken(ctx context.Context, token string) (jwt.RegisteredClaims, []string, error) } // KeySet is a set of publc JSON Web Keys that can be used to validate the signature @@ -37,16 +42,22 @@ type KeySet interface { } type oidcClient struct { - issuer string - provider *ProviderMetadata - providerLock *sync.Mutex - skipIssuerValidation bool - remoteKeySet KeySet - algorithms []string // Logger to use for logging, must be set Logger log.Logger - client *http.Client + issuer string + provider *ProviderMetadata + providerLock *sync.Mutex + skipIssuerValidation bool + accessTokenVerifyMethod string + remoteKeySet KeySet // TODO replace usage with keyfunc? + algorithms []string + + JWKSOptions config.JWKS + JWKS *keyfunc.JWKS + jwksLock *sync.Mutex + + httpClient *http.Client } // supportedAlgorithms is a list of algorithms explicitly supported by this @@ -69,10 +80,13 @@ func NewOIDCClient(opts ...Option) OIDCProvider { options := newOptions(opts...) return &oidcClient{ - Logger: options.Logger, - issuer: options.OidcIssuer, - client: options.HTTPClient, - providerLock: &sync.Mutex{}, + Logger: options.Logger, + issuer: options.OidcIssuer, + httpClient: options.HTTPClient, + accessTokenVerifyMethod: options.AccessTokenVerifyMethod, + JWKSOptions: options.JWKSOptions, // TODO I don't like that we pass down config options ... + providerLock: &sync.Mutex{}, + jwksLock: &sync.Mutex{}, } } @@ -85,7 +99,7 @@ func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) err if err != nil { return err } - resp, err := c.client.Do(req.WithContext(ctx)) + resp, err := c.httpClient.Do(req.WithContext(ctx)) if err != nil { return err } @@ -122,6 +136,32 @@ func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) err return nil } +func (c *oidcClient) getKeyfunc() *keyfunc.JWKS { + c.jwksLock.Lock() + defer c.jwksLock.Unlock() + if c.JWKS == nil { + var err error + c.Logger.Debug().Str("jwks", c.provider.JwksURI).Msg("discovered jwks endpoint") + options := keyfunc.Options{ + Client: c.httpClient, + RefreshErrorHandler: func(err error) { + c.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc") + }, + RefreshInterval: time.Minute * time.Duration(c.JWKSOptions.RefreshInterval), + RefreshRateLimit: time.Second * time.Duration(c.JWKSOptions.RefreshRateLimit), + RefreshTimeout: time.Second * time.Duration(c.JWKSOptions.RefreshTimeout), + RefreshUnknownKID: c.JWKSOptions.RefreshUnknownKID, + } + c.JWKS, err = keyfunc.Get(c.provider.JwksURI, options) + if err != nil { + c.JWKS = nil + c.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.") + return nil + } + } + return c.JWKS +} + type stringAsBool bool func (sb *stringAsBool) UnmarshalJSON(b []byte) error { @@ -186,7 +226,7 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc } token.SetAuthHeader(req) - resp, err := c.client.Do(req.WithContext(ctx)) + resp, err := c.httpClient.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -222,6 +262,59 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc }, nil } +func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (jwt.RegisteredClaims, []string, error) { + var mapClaims []string + if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil { + return jwt.RegisteredClaims{}, mapClaims, err + } + switch c.accessTokenVerifyMethod { + case config.AccessTokenVerificationJWT: + return c.verifyAccessTokenJWT(token) + case config.AccessTokenVerificationNone: + c.Logger.Debug().Msg("Access Token verification disabled") + return jwt.RegisteredClaims{}, mapClaims, nil + default: + c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting") + return jwt.RegisteredClaims{}, mapClaims, errors.New("unknown Access Token Verification method") + } +} + +// verifyAccessTokenJWT tries to parse and verify the access token as a JWT. +func (c *oidcClient) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) { + var claims jwt.RegisteredClaims + var mapClaims []string + jwks := c.getKeyfunc() + if jwks == nil { + return claims, mapClaims, errors.New("Error initializing jwks keyfunc") + } + + _, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc) + if err != nil { + return claims, mapClaims, err + } + _, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) + // TODO: decode mapClaims to sth readable + c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token") + if err != nil { + c.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.") + return claims, mapClaims, err + } + c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token") + if err != nil { + c.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.") + return claims, mapClaims, err + } + + if !claims.VerifyIssuer(c.issuer, true) { + vErr := jwt.ValidationError{} + vErr.Inner = jwt.ErrTokenInvalidIssuer + vErr.Errors |= jwt.ValidationErrorIssuer + return claims, mapClaims, vErr + } + + return claims, mapClaims, nil +} + func unmarshalResp(r *http.Response, body []byte, v interface{}) error { err := json.Unmarshal(body, &v) if err == nil { diff --git a/ocis-pkg/oidc/options.go b/ocis-pkg/oidc/options.go index 1f7cd64cce..9085fee367 100644 --- a/ocis-pkg/oidc/options.go +++ b/ocis-pkg/oidc/options.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/owncloud/ocis/v2/ocis-pkg/log" + "github.com/owncloud/ocis/v2/services/proxy/pkg/config" ) // Option defines a single option function. @@ -17,6 +18,11 @@ type Options struct { Logger log.Logger // The OpenID Connect Issuer URL OidcIssuer string + // JWKSOptions to use when retrieving keys + JWKSOptions config.JWKS + // AccessTokenVerifyMethod to use when verifying access tokens + // TODO pass a function or interface to verify? an AccessTokenVerifier? + AccessTokenVerifyMethod string } // newOptions initializes the available default options. @@ -44,8 +50,19 @@ func WithLogger(val log.Logger) Option { } } +// WithAccessTokenVerifyMethod provides a function to set the accessTokenVerifyMethod option. +func WithAccessTokenVerifyMethod(val string) Option { + return func(o *Options) { + o.AccessTokenVerifyMethod = val + } +} func WithHTTPClient(val *http.Client) Option { return func(o *Options) { o.HTTPClient = val } } +func WithJWKSOptions(val config.JWKS) Option { + return func(o *Options) { + o.JWKSOptions = val + } +} diff --git a/services/proxy/pkg/command/server.go b/services/proxy/pkg/command/server.go index cc3ea1082f..86ae967cfd 100644 --- a/services/proxy/pkg/command/server.go +++ b/services/proxy/pkg/command/server.go @@ -291,12 +291,12 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config, middleware.DefaultAccessTokenTTL(cfg.OIDC.UserinfoCache.TTL), middleware.HTTPClient(oidcHTTPClient), middleware.OIDCIss(cfg.OIDC.Issuer), - middleware.JWKSOptions(cfg.OIDC.JWKS), - middleware.AccessTokenVerifyMethod(cfg.OIDC.AccessTokenVerifyMethod), middleware.OIDCClient(oidc.NewOIDCClient( + oidc.WithAccessTokenVerifyMethod(cfg.OIDC.AccessTokenVerifyMethod), oidc.WithLogger(logger), oidc.WithHTTPClient(oidcHTTPClient), oidc.WithOidcIssuer(cfg.OIDC.Issuer), + oidc.WithJWKSOptions(cfg.OIDC.JWKS), )), )) authenticators = append(authenticators, middleware.PublicShareAuthenticator{ diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index a33d24cdb0..95c7a1528b 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -6,14 +6,11 @@ import ( "fmt" "net/http" "strings" - "sync" "time" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" - "github.com/owncloud/ocis/v2/services/proxy/pkg/config" - "github.com/MicahParks/keyfunc" "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" "github.com/shamaton/msgpack/v2" @@ -39,9 +36,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator { HTTPClient: options.HTTPClient, OIDCIss: options.OIDCIss, oidcClient: options.OIDCClient, - JWKSOptions: options.JWKS, AccessTokenVerifyMethod: options.AccessTokenVerifyMethod, - jwksLock: &sync.Mutex{}, } } @@ -55,16 +50,12 @@ type OIDCAuthenticator struct { DefaultTokenCacheTTL time.Duration oidcClient oidc.OIDCProvider AccessTokenVerifyMethod string - JWKSOptions config.JWKS - - jwksLock *sync.Mutex - JWKS *keyfunc.JWKS } func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) { var claims map[string]interface{} - // usea 64 bytes long hash to have 256-bit collision resistance. + // use a 64 bytes long hash to have 256-bit collision resistance. hash := make([]byte, 64) sha3.ShakeSum256(hash, []byte(token)) encodedHash := base64.URLEncoding.EncodeToString(hash) @@ -80,9 +71,10 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri } m.Logger.Error().Err(err).Msg("could not unmarshal userinfo") } + // TODO: use mClaims - aClaims, mClaims, err := m.verifyAccessToken(token) - //fmt.Println(mClaims) + aClaims, mClaims, err := m.oidcClient.VerifyAccessToken(req.Context(), token) + vals := make([]string, len(mClaims)) for k, v := range mClaims { s, _ := base64.StdEncoding.DecodeString(v) @@ -140,57 +132,6 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri return claims, nil } -// TODO: update jwt lib to have access to session id, or extract the session id and return it -func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, []string, error) { - var mapClaims []string - switch m.AccessTokenVerifyMethod { - case config.AccessTokenVerificationJWT: - return m.verifyAccessTokenJWT(token) - case config.AccessTokenVerificationNone: - m.Logger.Debug().Msg("Access Token verification disabled") - return jwt.RegisteredClaims{}, mapClaims, nil - default: - m.Logger.Error().Str("access_token_verify_method", m.AccessTokenVerifyMethod).Msg("Unknown Access Token verification setting") - return jwt.RegisteredClaims{}, mapClaims, errors.New("Unknown Access Token Verification method") - } -} - -// verifyAccessTokenJWT tries to parse and verify the access token as a JWT. -func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) { - var claims jwt.RegisteredClaims - var mapClaims []string - jwks := m.getKeyfunc() - if jwks == nil { - return claims, mapClaims, errors.New("Error initializing jwks keyfunc") - } - - _, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc) - if err != nil { - return claims, mapClaims, err - } - _, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) - // TODO: decode mapClaims to sth readable - m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token") - if err != nil { - m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.") - return claims, mapClaims, err - } - m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token") - if err != nil { - m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.") - return claims, mapClaims, err - } - - if !claims.VerifyIssuer(m.OIDCIss, true) { - vErr := jwt.ValidationError{} - vErr.Inner = jwt.ErrTokenInvalidIssuer - vErr.Errors |= jwt.ValidationErrorIssuer - return claims, mapClaims, vErr - } - - return claims, mapClaims, nil -} - // extractExpiration tries to extract the expriration time from the access token // If the access token does not have an exp claim it will fallback to the configured // default expiration @@ -212,36 +153,6 @@ func (m OIDCAuthenticator) shouldServe(req *http.Request) bool { return strings.HasPrefix(header, _bearerPrefix) } -func (m *OIDCAuthenticator) getKeyfunc() *keyfunc.JWKS { - m.jwksLock.Lock() - defer m.jwksLock.Unlock() - if m.JWKS == nil { - oidcMetadata, err := oidc.GetIDPMetadata(m.Logger, m.HTTPClient, m.OIDCIss) - if err != nil { - m.Logger.Error().Err(err).Msg("failed to decode provider openid-configuration") - return nil - } - m.Logger.Debug().Str("jwks", oidcMetadata.JwksURI).Msg("discovered jwks endpoint") - options := keyfunc.Options{ - Client: m.HTTPClient, - RefreshErrorHandler: func(err error) { - m.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc") - }, - RefreshInterval: time.Minute * time.Duration(m.JWKSOptions.RefreshInterval), - RefreshRateLimit: time.Second * time.Duration(m.JWKSOptions.RefreshRateLimit), - RefreshTimeout: time.Second * time.Duration(m.JWKSOptions.RefreshTimeout), - RefreshUnknownKID: m.JWKSOptions.RefreshUnknownKID, - } - m.JWKS, err = keyfunc.Get(oidcMetadata.JwksURI, options) - if err != nil { - m.JWKS = nil - m.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.") - return nil - } - } - return m.JWKS -} - // Authenticate implements the authenticator interface to authenticate requests via oidc auth. func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) { // there is no bearer token on the request, @@ -251,11 +162,6 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) // implement an early return here for paths we can't authenticate in this authenticator. return nil, false } - - // Force init of jwks keyfunc if needed (contacts the .well-known and jwks endpoints on first call) - if m.AccessTokenVerifyMethod == config.AccessTokenVerificationJWT && m.getKeyfunc() == nil { - return nil, false - } token := strings.TrimPrefix(r.Header.Get(_headerAuthorization), _bearerPrefix) claims, err := m.getClaims(token, r) diff --git a/services/proxy/pkg/middleware/options.go b/services/proxy/pkg/middleware/options.go index bc030f04e2..2b144a9e7e 100644 --- a/services/proxy/pkg/middleware/options.go +++ b/services/proxy/pkg/middleware/options.go @@ -219,13 +219,6 @@ func AccessTokenVerifyMethod(method string) Option { } } -// JWKSOptions sets the options for fetching the JWKS from the IDP -func JWKSOptions(jo config.JWKS) Option { - return func(o *Options) { - o.JWKS = jo - } -} - // RoleQuotas sets the role quota mapping setting func RoleQuotas(roleQuotas map[string]uint64) Option { return func(o *Options) {