Files
opencloud/pkg/oidc/client.go
Jörn Friedrich Dreyer b07b5a1149 use plain pkg module
Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
2025-01-13 16:42:19 +01:00

374 lines
12 KiB
Go

package oidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/MicahParks/keyfunc/v2"
goidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/opencloud-eu/opencloud/pkg/log"
"github.com/opencloud-eu/opencloud/services/proxy/pkg/config"
"golang.org/x/oauth2"
)
// OIDCClient describes the OpenID Connect operations offered by this package.
// It is declared as an interface so the client can be mocked during tests.
type OIDCClient interface {
// UserInfo requests the userinfo claims from the provider, authenticating
// the request with a token obtained from ts.
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
// VerifyAccessToken verifies token and returns the registered claims
// (including the sid claim) together with all claims as a map.
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error)
// VerifyLogoutToken parses and validates token as a logout token and returns
// its claims.
VerifyLogoutToken(ctx context.Context, token string) (*LogoutToken, error)
}
// KeySet is a set of public JSON Web Keys that can be used to validate the signature
// of JSON web tokens. This is expected to be backed by a remote key set through
// provider metadata discovery or an in-memory set of keys delivered out-of-band.
//
// In this file the oidcClient uses its KeySet to verify userinfo responses that
// are delivered as a signed JWT.
type KeySet interface {
// VerifySignature parses the JSON web token, verifies the signature, and returns
// the raw payload. Header and claim fields are validated by other parts of the
// package. For example, the KeySet does not need to check values such as signature
// algorithm, issuer, and audience since those are expected to be validated
// independently by the caller.
//
// If VerifySignature makes HTTP requests to verify the token, it's expected to
// use any HTTP client associated with the context through ClientContext.
VerifySignature(ctx context.Context, jwt string) (payload []byte, err error)
}
// RegClaimsWithSID combines the registered JWT claims with the session ID
// claim, which is decoded from the "sid" member of the token payload.
type RegClaimsWithSID struct {
// SessionID is the value of the "sid" claim.
SessionID string `json:"sid"`
jwt.RegisteredClaims
}
// oidcClient is the default OIDCClient implementation. Provider metadata and
// the JWKS are resolved lazily on first use and then cached on the instance.
type oidcClient struct {
// Logger to use for logging, must be set
Logger log.Logger
// issuer is the expected issuer; it is used to build the discovery URL and
// to validate the issuer of discovered metadata and of tokens.
issuer string
// provider holds the provider metadata, either injected through options or
// fetched by lookupWellKnownOpenidConfiguration.
provider *ProviderMetadata
// providerLock serializes the lazy discovery of provider metadata.
providerLock *sync.Mutex
// skipIssuerValidation disables the comparison between the discovered
// issuer and the configured issuer.
skipIssuerValidation bool
// accessTokenVerifyMethod selects how VerifyAccessToken verifies tokens.
accessTokenVerifyMethod string
// remoteKeySet verifies the signature of userinfo responses sent as JWT.
remoteKeySet KeySet
// algorithms lists the discovered ID token signing algorithms that are
// also contained in _supportedAlgorithms.
algorithms []string
// JWKSOptions carries the refresh settings used when creating JWKS.
JWKSOptions config.JWKS
// JWKS is the key function source; created by getKeyfunc when nil.
JWKS *keyfunc.JWKS
// jwksLock serializes the lazy creation of JWKS.
jwksLock *sync.Mutex
// httpClient is used for discovery, userinfo and JWKS requests.
httpClient *http.Client
}
// _supportedAlgorithms is the set of signing algorithms explicitly supported
// by this package. Algorithms advertised by a provider that are not listed
// here, such as HS256 or none, are dropped when the discovery document is
// processed and are therefore never accepted as valid signing methods.
var _supportedAlgorithms = map[string]bool{
RS256: true,
RS384: true,
RS512: true,
ES256: true,
ES384: true,
ES512: true,
PS256: true,
PS384: true,
PS512: true,
}
// NewOIDCClient returns an OIDCClient configured from the given options.
// Provider metadata, key set and JWKS may be injected through options; the
// locks guarding their lazy initialization are always created here.
func NewOIDCClient(opts ...Option) OIDCClient {
	o := newOptions(opts...)

	client := new(oidcClient)
	client.Logger = o.Logger
	client.issuer = o.OIDCIssuer
	client.httpClient = o.HTTPClient
	client.accessTokenVerifyMethod = o.AccessTokenVerifyMethod

	// TODO I don't like that we pass down config options ...
	client.JWKSOptions = o.JWKSOptions
	client.JWKS = o.JWKS
	client.remoteKeySet = o.KeySet
	client.provider = o.ProviderMetadata

	client.providerLock = &sync.Mutex{}
	client.jwksLock = &sync.Mutex{}
	return client
}
// lookupWellKnownOpenidConfiguration makes sure provider metadata is available.
//
// When no metadata has been set yet it fetches the discovery document from
// the issuer's well-known endpoint, validates the advertised issuer (unless
// skipIssuerValidation is set), stores the supported signing algorithms and
// creates the remote key set for the advertised JWKS URI. Once metadata is
// present the call returns immediately. The whole operation is serialized by
// providerLock.
//
// ctx bounds the discovery request only. It returns an error when the request
// fails, the response status is not 200, the body cannot be decoded or the
// issuer does not match.
func (c *oidcClient) lookupWellKnownOpenidConfiguration(ctx context.Context) error {
	c.providerLock.Lock()
	defer c.providerLock.Unlock()

	// Metadata was injected through options or discovered by an earlier call.
	if c.provider != nil {
		return nil
	}

	wellKnown := strings.TrimSuffix(c.issuer, "/") + wellknownPath
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
	if err != nil {
		return err
	}
	resp, err := c.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("unable to read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s: %s", resp.Status, body)
	}

	var p ProviderMetadata
	if err := unmarshalResp(resp, body, &p); err != nil {
		return fmt.Errorf("oidc: failed to decode provider discovery object: %w", err)
	}
	if !c.skipIssuerValidation && p.Issuer != c.issuer {
		return fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", c.issuer, p.Issuer)
	}

	// Keep only the advertised algorithms this package explicitly supports.
	var algs []string
	for _, a := range p.IDTokenSigningAlgValuesSupported {
		if _supportedAlgorithms[a] {
			algs = append(algs, a)
		}
	}

	c.provider = &p
	c.algorithms = algs
	// The key set is cached on the client and outlives this call, so it must
	// not be bound to the caller's context: once that context is cancelled
	// every later key fetch would fail. Only the HTTP client is carried over.
	c.remoteKeySet = goidc.NewRemoteKeySet(goidc.ClientContext(context.Background(), c.httpClient), p.JwksURI)
	return nil
}
// getKeyfunc returns the JWKS used to look up token verification keys. The
// JWKS is created on first use from the provider's JWKS URI and cached on the
// client; creation is serialized by jwksLock. It returns nil when the JWKS
// cannot be created, in which case the next call tries again.
func (c *oidcClient) getKeyfunc() *keyfunc.JWKS {
	c.jwksLock.Lock()
	defer c.jwksLock.Unlock()

	if c.JWKS != nil {
		return c.JWKS
	}

	jwksURI := c.provider.JwksURI
	c.Logger.Debug().Str("jwks", jwksURI).Msg("discovered jwks endpoint")

	// The refresh interval is configured in minutes, rate limit and timeout
	// in seconds.
	opts := keyfunc.Options{
		Client: c.httpClient,
		RefreshErrorHandler: func(refreshErr error) {
			c.Logger.Error().Err(refreshErr).Msg("There was an error with the jwt.Keyfunc")
		},
		RefreshInterval:   time.Duration(c.JWKSOptions.RefreshInterval) * time.Minute,
		RefreshRateLimit:  time.Duration(c.JWKSOptions.RefreshRateLimit) * time.Second,
		RefreshTimeout:    time.Duration(c.JWKSOptions.RefreshTimeout) * time.Second,
		RefreshUnknownKID: c.JWKSOptions.RefreshUnknownKID,
	}

	jwks, err := keyfunc.Get(jwksURI, opts)
	if err != nil {
		c.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
		return nil
	}
	c.JWKS = jwks
	return jwks
}
// stringAsBool is a bool that can be decoded from a JSON boolean as well as
// from a JSON string containing a boolean, because some providers deliver
// boolean claims as strings.
type stringAsBool bool

// UnmarshalJSON implements json.Unmarshaler.
//
// It accepts a bare JSON literal (true, false) or the same value wrapped in a
// JSON string ("true", "false"); the unquoted text is parsed with
// strconv.ParseBool. A JSON null leaves the value untouched. Any other input
// yields the error returned by strconv.ParseBool.
func (sb *stringAsBool) UnmarshalJSON(b []byte) error {
	raw := string(b)

	// By convention unmarshalers treat a JSON null as a no-op.
	if raw == "null" {
		return nil
	}

	// The decoder hands over JSON strings including their quotes, which
	// ParseBool does not understand; strip them first.
	if n := len(raw); n >= 2 && raw[0] == '"' && raw[n-1] == '"' {
		raw = raw[1 : n-1]
	}

	v, err := strconv.ParseBool(raw)
	if err != nil {
		return err
	}
	*sb = stringAsBool(v)
	return nil
}
// UserInfo represents the OpenID Connect userinfo claims.
//
// The exported fields carry a fixed subset of claims; the complete payload
// is retained and can be decoded with the Claims method.
type UserInfo struct {
// Subject is the value of the "sub" claim.
Subject string `json:"sub"`
// Profile is the value of the "profile" claim.
Profile string `json:"profile"`
// Email is the value of the "email" claim.
Email string `json:"email"`
// EmailVerified is the value of the "email_verified" claim.
EmailVerified bool `json:"email_verified"`
// claims is the raw JSON payload this UserInfo was decoded from.
claims []byte
}
// userInfoRaw is the decoding target for the userinfo payload. It mirrors the
// exported fields of UserInfo but uses a lenient type for email_verified.
type userInfoRaw struct {
Subject string `json:"sub"`
Profile string `json:"profile"`
Email string `json:"email"`
// Handle providers that return email_verified as a string
// https://forums.aws.amazon.com/thread.jspa?messageID=949441&#949441 and
// https://discuss.elastic.co/t/openid-error-after-authenticating-against-aws-cognito/206018/11
EmailVerified stringAsBool `json:"email_verified"`
}
// Claims decodes the raw userinfo JSON payload into v. It returns an error
// when no payload has been stored on u or when decoding fails.
func (u *UserInfo) Claims(v interface{}) error {
	if u.claims != nil {
		return json.Unmarshal(u.claims, v)
	}
	return errors.New("oidc: claims not set")
}
// UserInfo retrieves the userinfo claims from the provider's userinfo
// endpoint, authenticating with a token obtained from tokenSource.
//
// Provider metadata is looked up first. If the response has the content type
// application/jwt its signature is verified with the client's key set and the
// verified payload is decoded; otherwise the body is decoded as plain JSON.
//
// It returns an error when discovery fails, the provider has no userinfo
// endpoint, no token can be obtained, the request fails, the status is not
// 200, a JWT response cannot be verified or the payload cannot be decoded.
func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*UserInfo, error) {
	if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
		return nil, err
	}
	if c.provider.UserinfoEndpoint == "" {
		return nil, errors.New("oidc: user info endpoint is not supported by this provider")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.provider.UserinfoEndpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("oidc: create GET request: %w", err)
	}
	token, err := tokenSource.Token()
	if err != nil {
		return nil, fmt.Errorf("oidc: get access token: %w", err)
	}
	token.SetAuthHeader(req)

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%s: %s", resp.Status, body)
	}

	// An unparsable Content-Type is not an error: the body is then treated
	// as plain JSON.
	ct := resp.Header.Get("Content-Type")
	if mediaType, _, parseErr := mime.ParseMediaType(ct); parseErr == nil && mediaType == "application/jwt" {
		// The key set is only created during discovery; with injected
		// provider metadata and no injected key set it is still nil.
		if c.remoteKeySet == nil {
			return nil, errors.New("oidc: userinfo response is a jwt but no key set is available to verify it")
		}
		payload, err := c.remoteKeySet.VerifySignature(goidc.ClientContext(ctx, c.httpClient), string(body))
		if err != nil {
			return nil, fmt.Errorf("oidc: invalid userinfo jwt signature %w", err)
		}
		body = payload
	}

	var userInfo userInfoRaw
	if err := json.Unmarshal(body, &userInfo); err != nil {
		return nil, fmt.Errorf("oidc: failed to decode userinfo: %w", err)
	}
	return &UserInfo{
		Subject:       userInfo.Subject,
		Profile:       userInfo.Profile,
		Email:         userInfo.Email,
		EmailVerified: bool(userInfo.EmailVerified),
		claims:        body,
	}, nil
}
// VerifyAccessToken verifies token according to the configured access token
// verification method. Provider metadata is looked up first.
//
// With JWT verification the token is parsed and verified and its claims are
// returned. With verification disabled empty claims and no error are
// returned. Any other configured method results in an error.
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error) {
	if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
		return RegClaimsWithSID{}, jwt.MapClaims{}, err
	}

	method := c.accessTokenVerifyMethod
	if method == config.AccessTokenVerificationJWT {
		return c.verifyAccessTokenJWT(token)
	}
	if method == config.AccessTokenVerificationNone {
		c.Logger.Debug().Msg("Access Token verification disabled")
		return RegClaimsWithSID{}, jwt.MapClaims{}, nil
	}

	c.Logger.Error().Str("access_token_verify_method", method).Msg("Unknown Access Token verification setting")
	return RegClaimsWithSID{}, jwt.MapClaims{}, errors.New("unknown Access Token Verification method")
}
// verifyAccessTokenJWT parses the access token as a JWT and verifies it
// against the provider's JWKS and the expected issuer. On success it returns
// the registered claims (with sid) and all claims of the token as a map.
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, jwt.MapClaims, error) {
	var (
		registered RegClaimsWithSID
		all        = jwt.MapClaims{}
	)

	keys := c.getKeyfunc()
	if keys == nil {
		return registered, all, errors.New("error initializing jwks keyfunc")
	}

	// AD FS .well-known/openid-configuration has an optional `access_token_issuer` which takes precedence over `issuer`
	// See https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/586de7dd-3385-47c7-93a2-935d9e90441c
	expectedIssuer := c.provider.AccessTokenIssuer
	if expectedIssuer == "" {
		expectedIssuer = c.issuer
	}

	if _, err := jwt.ParseWithClaims(token, &registered, keys.Keyfunc, jwt.WithIssuer(expectedIssuer)); err != nil {
		return registered, all, err
	}

	// The token has been verified above; this second pass only collects all
	// claims into the map.
	_, _, mapErr := (&jwt.Parser{}).ParseUnverified(token, all)
	// TODO: decode mapClaims to sth readable
	c.Logger.Debug().Interface("access token", &registered).Msg("parsed access token")
	if mapErr != nil {
		c.Logger.Info().Err(mapErr).Msg("Failed to parse/verify the access token.")
		return registered, all, mapErr
	}
	return registered, all, nil
}
// VerifyLogoutToken parses rawToken as a logout token, verifies it against
// the provider's JWKS, the accepted signing algorithms and the configured
// issuer, and then applies the logout token specific checks. It returns the
// token's claims on success.
func (c *oidcClient) VerifyLogoutToken(ctx context.Context, rawToken string) (*LogoutToken, error) {
	var logout LogoutToken
	if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
		return nil, err
	}

	keys := c.getKeyfunc()
	if keys == nil {
		return nil, errors.New("error initializing jwks keyfunc")
	}

	// From the backchannel-logout spec: Like ID Tokens, selection of the
	// algorithm used is governed by the id_token_signing_alg_values_supported
	// Discovery parameter and the id_token_signed_response_alg Registration
	// parameter when they are used; otherwise, the value SHOULD be the default
	// of RS256
	algs := c.algorithms
	if len(algs) == 0 {
		algs = []string{RS256}
	}

	if _, err := jwt.ParseWithClaims(rawToken, &logout, keys.Keyfunc, jwt.WithValidMethods(algs), jwt.WithIssuer(c.issuer)); err != nil {
		c.Logger.Debug().Err(err).Msg("Failed to parse logout token")
		return nil, err
	}

	// ParseWithClaims has handled the generic token validation with the
	// options given above. What follows are the logout token specific checks,
	// evaluated in order.
	switch {
	case logout.Subject == "" && logout.SessionId == "":
		// 1. The Logout Token must contain a sub Claim, a sid Claim, or both.
		return nil, fmt.Errorf("oidc: logout token must contain either sub or sid and MAY contain both")
	case logout.Events.Event == nil:
		// 2. The Logout Token must contain an events Claim whose value is a JSON
		// object containing the member name http://schemas.openid.net/event/backchannel-logout.
		return nil, fmt.Errorf("oidc: logout token must contain logout event")
	case logout.Nonce != nil:
		// 3. The Logout Token must not contain a nonce Claim.
		return nil, fmt.Errorf("oidc: nonce on logout token MUST NOT be present")
	}
	return &logout, nil
}
// unmarshalResp decodes body as JSON into v and, on failure, returns an error
// that also reports the response's Content-Type.
//
// r is only consulted for its Content-Type header. When decoding fails and
// the header names application/json the error states that the advertised
// JSON could not be decoded; otherwise it states which content type was
// received instead. In both cases the JSON decoding error is wrapped.
func unmarshalResp(r *http.Response, body []byte, v interface{}) error {
	err := json.Unmarshal(body, &v)
	if err == nil {
		return nil
	}

	ct := r.Header.Get("Content-Type")
	// Use a separate variable for the media type parse result so the JSON
	// decoding error is not overwritten before it is reported.
	mediaType, _, parseErr := mime.ParseMediaType(ct)
	if parseErr == nil && mediaType == "application/json" {
		return fmt.Errorf("got Content-Type = application/json, but could not unmarshal as JSON: %w", err)
	}
	return fmt.Errorf("expected Content-Type = application/json, got %q: %w", ct, err)
}