move verify access token code to oidc client

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
Jörn Friedrich Dreyer
2023-04-12 14:11:47 +02:00
committed by Christian Richter
parent 469534b321
commit b608d0b0f9
5 changed files with 129 additions and 120 deletions

View File

@@ -10,15 +10,20 @@ import (
"net/http"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc"
gOidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v4"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
"golang.org/x/oauth2"
)
// OIDCProvider used to mock the oidc provider during tests
type OIDCProvider interface {
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
VerifyAccessToken(ctx context.Context, token string) (jwt.RegisteredClaims, []string, error)
}
// KeySet is a set of publc JSON Web Keys that can be used to validate the signature
@@ -37,16 +42,22 @@ type KeySet interface {
}
type oidcClient struct {
issuer string
provider *ProviderMetadata
providerLock *sync.Mutex
skipIssuerValidation bool
remoteKeySet KeySet
algorithms []string
// Logger to use for logging, must be set
Logger log.Logger
client *http.Client
issuer string
provider *ProviderMetadata
providerLock *sync.Mutex
skipIssuerValidation bool
accessTokenVerifyMethod string
remoteKeySet KeySet // TODO replace usage with keyfunc?
algorithms []string
JWKSOptions config.JWKS
JWKS *keyfunc.JWKS
jwksLock *sync.Mutex
httpClient *http.Client
}
// supportedAlgorithms is a list of algorithms explicitly supported by this
@@ -69,10 +80,13 @@ func NewOIDCClient(opts ...Option) OIDCProvider {
options := newOptions(opts...)
return &oidcClient{
Logger: options.Logger,
issuer: options.OidcIssuer,
client: options.HTTPClient,
providerLock: &sync.Mutex{},
Logger: options.Logger,
issuer: options.OidcIssuer,
httpClient: options.HTTPClient,
accessTokenVerifyMethod: options.AccessTokenVerifyMethod,
JWKSOptions: options.JWKSOptions, // TODO I don't like that we pass down config options ...
providerLock: &sync.Mutex{},
jwksLock: &sync.Mutex{},
}
}
@@ -85,7 +99,7 @@ func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) err
if err != nil {
return err
}
resp, err := c.client.Do(req.WithContext(ctx))
resp, err := c.httpClient.Do(req.WithContext(ctx))
if err != nil {
return err
}
@@ -122,6 +136,32 @@ func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) err
return nil
}
func (c *oidcClient) getKeyfunc() *keyfunc.JWKS {
c.jwksLock.Lock()
defer c.jwksLock.Unlock()
if c.JWKS == nil {
var err error
c.Logger.Debug().Str("jwks", c.provider.JwksURI).Msg("discovered jwks endpoint")
options := keyfunc.Options{
Client: c.httpClient,
RefreshErrorHandler: func(err error) {
c.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
},
RefreshInterval: time.Minute * time.Duration(c.JWKSOptions.RefreshInterval),
RefreshRateLimit: time.Second * time.Duration(c.JWKSOptions.RefreshRateLimit),
RefreshTimeout: time.Second * time.Duration(c.JWKSOptions.RefreshTimeout),
RefreshUnknownKID: c.JWKSOptions.RefreshUnknownKID,
}
c.JWKS, err = keyfunc.Get(c.provider.JwksURI, options)
if err != nil {
c.JWKS = nil
c.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
return nil
}
}
return c.JWKS
}
type stringAsBool bool
func (sb *stringAsBool) UnmarshalJSON(b []byte) error {
@@ -186,7 +226,7 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc
}
token.SetAuthHeader(req)
resp, err := c.client.Do(req.WithContext(ctx))
resp, err := c.httpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@@ -222,6 +262,59 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc
}, nil
}
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (jwt.RegisteredClaims, []string, error) {
var mapClaims []string
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
return jwt.RegisteredClaims{}, mapClaims, err
}
switch c.accessTokenVerifyMethod {
case config.AccessTokenVerificationJWT:
return c.verifyAccessTokenJWT(token)
case config.AccessTokenVerificationNone:
c.Logger.Debug().Msg("Access Token verification disabled")
return jwt.RegisteredClaims{}, mapClaims, nil
default:
c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
return jwt.RegisteredClaims{}, mapClaims, errors.New("unknown Access Token Verification method")
}
}
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
func (c *oidcClient) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) {
var claims jwt.RegisteredClaims
var mapClaims []string
jwks := c.getKeyfunc()
if jwks == nil {
return claims, mapClaims, errors.New("Error initializing jwks keyfunc")
}
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
if err != nil {
return claims, mapClaims, err
}
_, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
// TODO: decode mapClaims to sth readable
c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {
c.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
return claims, mapClaims, err
}
c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {
c.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
return claims, mapClaims, err
}
if !claims.VerifyIssuer(c.issuer, true) {
vErr := jwt.ValidationError{}
vErr.Inner = jwt.ErrTokenInvalidIssuer
vErr.Errors |= jwt.ValidationErrorIssuer
return claims, mapClaims, vErr
}
return claims, mapClaims, nil
}
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
err := json.Unmarshal(body, &v)
if err == nil {

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
)
// Option defines a single option function.
@@ -17,6 +18,11 @@ type Options struct {
Logger log.Logger
// The OpenID Connect Issuer URL
OidcIssuer string
// JWKSOptions to use when retrieving keys
JWKSOptions config.JWKS
// AccessTokenVerifyMethod to use when verifying access tokens
// TODO pass a function or interface to verify? an AccessTokenVerifier?
AccessTokenVerifyMethod string
}
// newOptions initializes the available default options.
@@ -44,8 +50,19 @@ func WithLogger(val log.Logger) Option {
}
}
// WithAccessTokenVerifyMethod provides a function to set the accessTokenVerifyMethod option.
func WithAccessTokenVerifyMethod(val string) Option {
return func(o *Options) {
o.AccessTokenVerifyMethod = val
}
}
func WithHTTPClient(val *http.Client) Option {
return func(o *Options) {
o.HTTPClient = val
}
}
func WithJWKSOptions(val config.JWKS) Option {
return func(o *Options) {
o.JWKSOptions = val
}
}

View File

@@ -291,12 +291,12 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config,
middleware.DefaultAccessTokenTTL(cfg.OIDC.UserinfoCache.TTL),
middleware.HTTPClient(oidcHTTPClient),
middleware.OIDCIss(cfg.OIDC.Issuer),
middleware.JWKSOptions(cfg.OIDC.JWKS),
middleware.AccessTokenVerifyMethod(cfg.OIDC.AccessTokenVerifyMethod),
middleware.OIDCClient(oidc.NewOIDCClient(
oidc.WithAccessTokenVerifyMethod(cfg.OIDC.AccessTokenVerifyMethod),
oidc.WithLogger(logger),
oidc.WithHTTPClient(oidcHTTPClient),
oidc.WithOidcIssuer(cfg.OIDC.Issuer),
oidc.WithJWKSOptions(cfg.OIDC.JWKS),
)),
))
authenticators = append(authenticators, middleware.PublicShareAuthenticator{

View File

@@ -6,14 +6,11 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
"github.com/MicahParks/keyfunc"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"github.com/shamaton/msgpack/v2"
@@ -39,9 +36,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
HTTPClient: options.HTTPClient,
OIDCIss: options.OIDCIss,
oidcClient: options.OIDCClient,
JWKSOptions: options.JWKS,
AccessTokenVerifyMethod: options.AccessTokenVerifyMethod,
jwksLock: &sync.Mutex{},
}
}
@@ -55,16 +50,12 @@ type OIDCAuthenticator struct {
DefaultTokenCacheTTL time.Duration
oidcClient oidc.OIDCProvider
AccessTokenVerifyMethod string
JWKSOptions config.JWKS
jwksLock *sync.Mutex
JWKS *keyfunc.JWKS
}
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
var claims map[string]interface{}
// usea 64 bytes long hash to have 256-bit collision resistance.
// use a 64 bytes long hash to have 256-bit collision resistance.
hash := make([]byte, 64)
sha3.ShakeSum256(hash, []byte(token))
encodedHash := base64.URLEncoding.EncodeToString(hash)
@@ -80,9 +71,10 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
}
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
}
// TODO: use mClaims
aClaims, mClaims, err := m.verifyAccessToken(token)
//fmt.Println(mClaims)
aClaims, mClaims, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
vals := make([]string, len(mClaims))
for k, v := range mClaims {
s, _ := base64.StdEncoding.DecodeString(v)
@@ -140,57 +132,6 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
return claims, nil
}
// TODO: update jwt lib to have access to session id, or extract the session id and return it
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, []string, error) {
var mapClaims []string
switch m.AccessTokenVerifyMethod {
case config.AccessTokenVerificationJWT:
return m.verifyAccessTokenJWT(token)
case config.AccessTokenVerificationNone:
m.Logger.Debug().Msg("Access Token verification disabled")
return jwt.RegisteredClaims{}, mapClaims, nil
default:
m.Logger.Error().Str("access_token_verify_method", m.AccessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
return jwt.RegisteredClaims{}, mapClaims, errors.New("Unknown Access Token Verification method")
}
}
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) {
var claims jwt.RegisteredClaims
var mapClaims []string
jwks := m.getKeyfunc()
if jwks == nil {
return claims, mapClaims, errors.New("Error initializing jwks keyfunc")
}
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
if err != nil {
return claims, mapClaims, err
}
_, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
// TODO: decode mapClaims to sth readable
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
return claims, mapClaims, err
}
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
return claims, mapClaims, err
}
if !claims.VerifyIssuer(m.OIDCIss, true) {
vErr := jwt.ValidationError{}
vErr.Inner = jwt.ErrTokenInvalidIssuer
vErr.Errors |= jwt.ValidationErrorIssuer
return claims, mapClaims, vErr
}
return claims, mapClaims, nil
}
// extractExpiration tries to extract the expriration time from the access token
// If the access token does not have an exp claim it will fallback to the configured
// default expiration
@@ -212,36 +153,6 @@ func (m OIDCAuthenticator) shouldServe(req *http.Request) bool {
return strings.HasPrefix(header, _bearerPrefix)
}
func (m *OIDCAuthenticator) getKeyfunc() *keyfunc.JWKS {
m.jwksLock.Lock()
defer m.jwksLock.Unlock()
if m.JWKS == nil {
oidcMetadata, err := oidc.GetIDPMetadata(m.Logger, m.HTTPClient, m.OIDCIss)
if err != nil {
m.Logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
return nil
}
m.Logger.Debug().Str("jwks", oidcMetadata.JwksURI).Msg("discovered jwks endpoint")
options := keyfunc.Options{
Client: m.HTTPClient,
RefreshErrorHandler: func(err error) {
m.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
},
RefreshInterval: time.Minute * time.Duration(m.JWKSOptions.RefreshInterval),
RefreshRateLimit: time.Second * time.Duration(m.JWKSOptions.RefreshRateLimit),
RefreshTimeout: time.Second * time.Duration(m.JWKSOptions.RefreshTimeout),
RefreshUnknownKID: m.JWKSOptions.RefreshUnknownKID,
}
m.JWKS, err = keyfunc.Get(oidcMetadata.JwksURI, options)
if err != nil {
m.JWKS = nil
m.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
return nil
}
}
return m.JWKS
}
// Authenticate implements the authenticator interface to authenticate requests via oidc auth.
func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
// there is no bearer token on the request,
@@ -251,11 +162,6 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool)
// implement an early return here for paths we can't authenticate in this authenticator.
return nil, false
}
// Force init of jwks keyfunc if needed (contacts the .well-known and jwks endpoints on first call)
if m.AccessTokenVerifyMethod == config.AccessTokenVerificationJWT && m.getKeyfunc() == nil {
return nil, false
}
token := strings.TrimPrefix(r.Header.Get(_headerAuthorization), _bearerPrefix)
claims, err := m.getClaims(token, r)

View File

@@ -219,13 +219,6 @@ func AccessTokenVerifyMethod(method string) Option {
}
}
// JWKSOptions sets the options for fetching the JWKS from the IDP
func JWKSOptions(jo config.JWKS) Option {
return func(o *Options) {
o.JWKS = jo
}
}
// RoleQuotas sets the role quota mapping setting
func RoleQuotas(roleQuotas map[string]uint64) Option {
return func(o *Options) {