mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-01-26 15:02:52 -05:00
274 lines
8.6 KiB
Go
274 lines
8.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
|
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
|
|
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
|
|
|
|
"github.com/MicahParks/keyfunc"
|
|
gOidc "github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/pkg/errors"
|
|
"github.com/shamaton/msgpack/v2"
|
|
store "go-micro.dev/v4/store"
|
|
"golang.org/x/crypto/sha3"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
const (
	// _headerAuthorization is the name of the request header that is
	// inspected for credentials.
	_headerAuthorization = "Authorization"

	// _bearerPrefix is the prefix (including the trailing space) that marks
	// the value of the authorization header as a bearer token.
	_bearerPrefix = "Bearer "
)
|
|
|
|
// OIDCProvider used to mock the oidc provider during tests
|
|
type OIDCProvider interface {
|
|
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*gOidc.UserInfo, error)
|
|
}
|
|
|
|
// NewOIDCAuthenticator returns a ready to use authenticator which can handle OIDC authentication.
|
|
func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
|
|
options := newOptions(opts...)
|
|
|
|
return &OIDCAuthenticator{
|
|
Logger: options.Logger,
|
|
userInfoCache: options.Cache,
|
|
DefaultTokenCacheTTL: options.DefaultAccessTokenTTL,
|
|
HTTPClient: options.HTTPClient,
|
|
OIDCIss: options.OIDCIss,
|
|
ProviderFunc: options.OIDCProviderFunc,
|
|
JWKSOptions: options.JWKS,
|
|
AccessTokenVerifyMethod: options.AccessTokenVerifyMethod,
|
|
providerLock: &sync.Mutex{},
|
|
jwksLock: &sync.Mutex{},
|
|
}
|
|
}
|
|
|
|
// OIDCAuthenticator is an authenticator responsible for OIDC authentication.
|
|
type OIDCAuthenticator struct {
|
|
Logger log.Logger
|
|
HTTPClient *http.Client
|
|
OIDCIss string
|
|
userInfoCache store.Store
|
|
DefaultTokenCacheTTL time.Duration
|
|
ProviderFunc func() (OIDCProvider, error)
|
|
AccessTokenVerifyMethod string
|
|
JWKSOptions config.JWKS
|
|
|
|
providerLock *sync.Mutex
|
|
provider OIDCProvider
|
|
|
|
jwksLock *sync.Mutex
|
|
JWKS *keyfunc.JWKS
|
|
}
|
|
|
|
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
|
|
var claims map[string]interface{}
|
|
|
|
// usea 64 bytes long hash to have 256-bit collision resistance.
|
|
hash := make([]byte, 64)
|
|
sha3.ShakeSum256(hash, []byte(token))
|
|
encodedHash := base64.URLEncoding.EncodeToString(hash)
|
|
|
|
record, err := m.userInfoCache.Read(encodedHash)
|
|
if err != nil && err != store.ErrNotFound {
|
|
m.Logger.Error().Err(err).Msg("could not read from userinfo cache")
|
|
}
|
|
if len(record) > 0 {
|
|
if err = msgpack.Unmarshal(record[0].Value, &claims); err == nil {
|
|
m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
|
return claims, nil
|
|
}
|
|
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
|
|
}
|
|
aClaims, err := m.verifyAccessToken(token)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to verify access token")
|
|
}
|
|
|
|
oauth2Token := &oauth2.Token{
|
|
AccessToken: token,
|
|
}
|
|
|
|
userInfo, err := m.getProvider().UserInfo(
|
|
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
|
|
oauth2.StaticTokenSource(oauth2Token),
|
|
)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "failed to get userinfo")
|
|
}
|
|
if err := userInfo.Claims(&claims); err != nil {
|
|
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
|
|
}
|
|
|
|
expiration := m.extractExpiration(aClaims)
|
|
go func() {
|
|
if d, err := msgpack.Marshal(claims); err != nil {
|
|
m.Logger.Error().Err(err).Msg("failed to marshal claims for userinfo cache")
|
|
} else {
|
|
err = m.userInfoCache.Write(&store.Record{
|
|
Key: encodedHash,
|
|
Value: d,
|
|
Expiry: time.Until(expiration),
|
|
})
|
|
if err != nil {
|
|
m.Logger.Error().Err(err).Msg("failed to write to userinfo cache")
|
|
}
|
|
}
|
|
}()
|
|
|
|
m.Logger.Debug().Interface("claims", claims).Msg("extracted claims")
|
|
return claims, nil
|
|
}
|
|
|
|
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, error) {
|
|
switch m.AccessTokenVerifyMethod {
|
|
case config.AccessTokenVerificationJWT:
|
|
return m.verifyAccessTokenJWT(token)
|
|
case config.AccessTokenVerificationNone:
|
|
m.Logger.Debug().Msg("Access Token verification disabled")
|
|
return jwt.RegisteredClaims{}, nil
|
|
default:
|
|
m.Logger.Error().Str("access_token_verify_method", m.AccessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
|
return jwt.RegisteredClaims{}, errors.New("Unknown Access Token Verification method")
|
|
}
|
|
}
|
|
|
|
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
|
|
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, error) {
|
|
var claims jwt.RegisteredClaims
|
|
jwks := m.getKeyfunc()
|
|
if jwks == nil {
|
|
return claims, errors.New("Error initializing jwks keyfunc")
|
|
}
|
|
|
|
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
|
|
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
|
if err != nil {
|
|
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
|
return claims, err
|
|
}
|
|
|
|
if !claims.VerifyIssuer(m.OIDCIss, true) {
|
|
vErr := jwt.ValidationError{}
|
|
vErr.Inner = jwt.ErrTokenInvalidIssuer
|
|
vErr.Errors |= jwt.ValidationErrorIssuer
|
|
return claims, vErr
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// extractExpiration tries to extract the expriration time from the access token
|
|
// If the access token does not have an exp claim it will fallback to the configured
|
|
// default expiration
|
|
func (m OIDCAuthenticator) extractExpiration(aClaims jwt.RegisteredClaims) time.Time {
|
|
defaultExpiration := time.Now().Add(m.DefaultTokenCacheTTL)
|
|
if aClaims.ExpiresAt != nil {
|
|
m.Logger.Debug().Str("exp", aClaims.ExpiresAt.String()).Msg("Expiration Time from access_token")
|
|
return aClaims.ExpiresAt.Time
|
|
}
|
|
return defaultExpiration
|
|
}
|
|
|
|
func (m OIDCAuthenticator) shouldServe(req *http.Request) bool {
|
|
if m.OIDCIss == "" {
|
|
return false
|
|
}
|
|
|
|
header := req.Header.Get(_headerAuthorization)
|
|
return strings.HasPrefix(header, _bearerPrefix)
|
|
}
|
|
|
|
func (m *OIDCAuthenticator) getKeyfunc() *keyfunc.JWKS {
|
|
m.jwksLock.Lock()
|
|
defer m.jwksLock.Unlock()
|
|
if m.JWKS == nil {
|
|
oidcMetadata, err := oidc.GetIDPMetadata(m.Logger, m.HTTPClient, m.OIDCIss)
|
|
if err != nil {
|
|
m.Logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
|
|
return nil
|
|
}
|
|
m.Logger.Debug().Str("jwks", oidcMetadata.JwksURI).Msg("discovered jwks endpoint")
|
|
options := keyfunc.Options{
|
|
Client: m.HTTPClient,
|
|
RefreshErrorHandler: func(err error) {
|
|
m.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
|
|
},
|
|
RefreshInterval: time.Minute * time.Duration(m.JWKSOptions.RefreshInterval),
|
|
RefreshRateLimit: time.Second * time.Duration(m.JWKSOptions.RefreshRateLimit),
|
|
RefreshTimeout: time.Second * time.Duration(m.JWKSOptions.RefreshTimeout),
|
|
RefreshUnknownKID: m.JWKSOptions.RefreshUnknownKID,
|
|
}
|
|
m.JWKS, err = keyfunc.Get(oidcMetadata.JwksURI, options)
|
|
if err != nil {
|
|
m.JWKS = nil
|
|
m.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
|
|
return nil
|
|
}
|
|
}
|
|
return m.JWKS
|
|
}
|
|
|
|
func (m *OIDCAuthenticator) getProvider() OIDCProvider {
|
|
m.providerLock.Lock()
|
|
defer m.providerLock.Unlock()
|
|
if m.provider == nil {
|
|
// Lazily initialize a provider
|
|
|
|
// provider needs to be cached as when it is created
|
|
// it will fetch the keys from the issuer using the .well-known
|
|
// endpoint
|
|
provider, err := m.ProviderFunc()
|
|
if err != nil {
|
|
m.Logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
|
|
return nil
|
|
}
|
|
|
|
m.provider = provider
|
|
}
|
|
return m.provider
|
|
}
|
|
|
|
// Authenticate implements the authenticator interface to authenticate requests via oidc auth.
|
|
func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
|
|
// there is no bearer token on the request,
|
|
if !m.shouldServe(r) || isPublicPath(r.URL.Path) {
|
|
// The authentication of public path requests is handled by another authenticator.
|
|
// Since we can't guarantee the order of execution of the authenticators, we better
|
|
// implement an early return here for paths we can't authenticate in this authenticator.
|
|
return nil, false
|
|
}
|
|
|
|
if m.getProvider() == nil {
|
|
return nil, false
|
|
}
|
|
|
|
// Force init of jwks keyfunc if needed (contacts the .well-known and jwks endpoints on first call)
|
|
if m.AccessTokenVerifyMethod == config.AccessTokenVerificationJWT && m.getKeyfunc() == nil {
|
|
return nil, false
|
|
}
|
|
token := strings.TrimPrefix(r.Header.Get(_headerAuthorization), _bearerPrefix)
|
|
|
|
claims, err := m.getClaims(token, r)
|
|
if err != nil {
|
|
m.Logger.Error().
|
|
Err(err).
|
|
Str("authenticator", "oidc").
|
|
Str("path", r.URL.Path).
|
|
Msg("failed to authenticate the request")
|
|
return nil, false
|
|
}
|
|
m.Logger.Debug().
|
|
Str("authenticator", "oidc").
|
|
Str("path", r.URL.Path).
|
|
Msg("successfully authenticated request")
|
|
return r.WithContext(oidc.NewContext(r.Context(), claims)), true
|
|
}
|