mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-05-24 16:41:35 -04:00
Merge pull request #4374 from owncloud/rewrite-auth-middleware
[full-ci] Rewrite of the authentication middleware
This commit is contained in:
7
changelog/unreleased/rewrite-authentication.md
Normal file
7
changelog/unreleased/rewrite-authentication.md
Normal file
@@ -0,0 +1,7 @@
|
||||
Enhancement: Rewrite of the request authentication middleware
|
||||
|
||||
There were some flaws in the authentication middleware which were resolved by this rewrite.
|
||||
This rewrite also introduced the need to manually mark certain paths as "unprotected" if
|
||||
requests to these paths must not be authenticated.
|
||||
|
||||
https://github.com/owncloud/ocis/pull/4374
|
||||
@@ -161,6 +161,43 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config)
|
||||
Timeout: time.Second * 10,
|
||||
}
|
||||
|
||||
var authenticators []middleware.Authenticator
|
||||
if cfg.EnableBasicAuth {
|
||||
logger.Warn().Msg("basic auth enabled, use only for testing or development")
|
||||
authenticators = append(authenticators, middleware.BasicAuthenticator{
|
||||
Logger: logger,
|
||||
UserProvider: userProvider,
|
||||
})
|
||||
}
|
||||
authenticators = append(authenticators, middleware.NewOIDCAuthenticator(
|
||||
logger,
|
||||
cfg.OIDC.UserinfoCache.TTL,
|
||||
oidcHTTPClient,
|
||||
cfg.OIDC.Issuer,
|
||||
func() (middleware.OIDCProvider, error) {
|
||||
// Initialize a provider by specifying the issuer URL.
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
return oidc.NewProvider(
|
||||
context.WithValue(ctx, oauth2.HTTPClient, oidcHTTPClient),
|
||||
cfg.OIDC.Issuer,
|
||||
)
|
||||
},
|
||||
cfg.OIDC.JWKS,
|
||||
cfg.OIDC.AccessTokenVerifyMethod,
|
||||
))
|
||||
authenticators = append(authenticators, middleware.PublicShareAuthenticator{
|
||||
Logger: logger,
|
||||
RevaGatewayClient: revaClient,
|
||||
})
|
||||
|
||||
authenticators = append(authenticators, middleware.SignedURLAuthenticator{
|
||||
Logger: logger,
|
||||
PreSignedURLConfig: cfg.PreSignedURL,
|
||||
UserProvider: userProvider,
|
||||
Store: storeClient,
|
||||
})
|
||||
|
||||
return alice.New(
|
||||
// first make sure we log all requests and redirect to https if necessary
|
||||
pkgmiddleware.TraceContext,
|
||||
@@ -174,38 +211,12 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config)
|
||||
oidcHTTPClient,
|
||||
),
|
||||
|
||||
// now that we established the basics, on with authentication middleware
|
||||
middleware.Authentication(
|
||||
// OIDC Options
|
||||
middleware.OIDCProviderFunc(func() (middleware.OIDCProvider, error) {
|
||||
// Initialize a provider by specifying the issuer URL.
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
return oidc.NewProvider(
|
||||
context.WithValue(ctx, oauth2.HTTPClient, oidcHTTPClient),
|
||||
cfg.OIDC.Issuer,
|
||||
)
|
||||
}),
|
||||
middleware.HTTPClient(oidcHTTPClient),
|
||||
middleware.TokenCacheSize(cfg.OIDC.UserinfoCache.Size),
|
||||
middleware.TokenCacheTTL(time.Second*time.Duration(cfg.OIDC.UserinfoCache.TTL)),
|
||||
middleware.AccessTokenVerifyMethod(cfg.OIDC.AccessTokenVerifyMethod),
|
||||
middleware.JWKSOptions(cfg.OIDC.JWKS),
|
||||
|
||||
// basic Options
|
||||
middleware.Logger(logger),
|
||||
middleware.EnableBasicAuth(cfg.EnableBasicAuth),
|
||||
middleware.UserProvider(userProvider),
|
||||
middleware.OIDCIss(cfg.OIDC.Issuer),
|
||||
middleware.UserOIDCClaim(cfg.UserOIDCClaim),
|
||||
middleware.UserCS3Claim(cfg.UserCS3Claim),
|
||||
authenticators,
|
||||
middleware.CredentialsByUserAgent(cfg.AuthMiddleware.CredentialsByUserAgent),
|
||||
),
|
||||
middleware.SignedURLAuth(
|
||||
middleware.Logger(logger),
|
||||
middleware.PreSignedURLConfig(cfg.PreSignedURL),
|
||||
middleware.UserProvider(userProvider),
|
||||
middleware.Store(storeClient),
|
||||
middleware.OIDCIss(cfg.OIDC.Issuer),
|
||||
middleware.EnableBasicAuth(cfg.EnableBasicAuth),
|
||||
),
|
||||
middleware.AccountResolver(
|
||||
middleware.Logger(logger),
|
||||
@@ -228,9 +239,5 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config)
|
||||
middleware.TokenManagerConfig(*cfg.TokenManager),
|
||||
middleware.RevaGatewayClient(revaClient),
|
||||
),
|
||||
middleware.PublicShareAuth(
|
||||
middleware.Logger(logger),
|
||||
middleware.RevaGatewayClient(revaClient),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,10 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/webdav"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -12,48 +16,156 @@ var (
|
||||
SupportedAuthStrategies []string
|
||||
|
||||
// ProxyWwwAuthenticate is a list of endpoints that do not rely on reva underlying authentication, such as ocs.
|
||||
// services that fallback to reva authentication are declared in the "frontend" command on oCIS. It is a list of strings
|
||||
// to be regexp compiled.
|
||||
ProxyWwwAuthenticate = []string{"/ocs/v[12].php/cloud/"}
|
||||
// services that fallback to reva authentication are declared in the "frontend" command on oCIS. It is a list of
|
||||
// regexp.Regexp which are safe to use concurrently.
|
||||
ProxyWwwAuthenticate = []regexp.Regexp{*regexp.MustCompile("/ocs/v[12].php/cloud/")}
|
||||
|
||||
// WWWAuthenticate captures the Www-Authenticate header string.
|
||||
WWWAuthenticate = "Www-Authenticate"
|
||||
_publicPaths = [...]string{
|
||||
"/dav/public-files/",
|
||||
"/remote.php/dav/public-files/",
|
||||
"/remote.php/ocs/apps/files_sharing/api/v1/tokeninfo/unprotected",
|
||||
"/ocs/v1.php/cloud/capabilities",
|
||||
}
|
||||
// _unprotectedPaths contains paths which don't need to be authenticated.
|
||||
_unprotectedPaths = map[string]struct{}{
|
||||
"/": {},
|
||||
"/login": {},
|
||||
"/app/list": {},
|
||||
"/config.json": {},
|
||||
"/manifest.json": {},
|
||||
"/oidc-callback.html": {},
|
||||
"/oidc-callback": {},
|
||||
"/settings.js": {},
|
||||
"/data": {},
|
||||
"/konnect/v1/userinfo": {},
|
||||
"/status.php": {},
|
||||
"/favicon.ico": {},
|
||||
"/ocs/v1.php/config": {},
|
||||
"/ocs/v2.php/config": {},
|
||||
}
|
||||
// _unprotectedPathPrefixes contains paths which don't need to be authenticated.
|
||||
_unprotectedPathPrefixes = [...]string{
|
||||
"/files",
|
||||
"/data",
|
||||
"/account",
|
||||
"/s/",
|
||||
"/external/spaces",
|
||||
"/apps/openidconnect/redirect",
|
||||
"/settings",
|
||||
"/user-management",
|
||||
"/.well-known",
|
||||
"/js",
|
||||
"/css",
|
||||
"/icons",
|
||||
"/themes",
|
||||
"/signin",
|
||||
"/konnect",
|
||||
}
|
||||
)
|
||||
|
||||
// userAgentLocker aids in dependency injection for helper methods. The set of fields is arbitrary and the only relation
|
||||
// they share is to fulfill their duty and lock a User-Agent to its correct challenge if configured.
|
||||
type userAgentLocker struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
locks map[string]string // locks represents a reva user-agent:challenge mapping.
|
||||
fallback string
|
||||
const (
|
||||
// WwwAuthenticate captures the Www-Authenticate header string.
|
||||
WwwAuthenticate = "Www-Authenticate"
|
||||
)
|
||||
|
||||
// Authenticator is the common interface implemented by all request authenticators.
|
||||
type Authenticator interface {
|
||||
// Authenticate is used to authenticate incoming HTTP requests.
|
||||
// The Authenticator may augment the request with user info or anything related to the
|
||||
// authentication and return the augmented request.
|
||||
Authenticate(*http.Request) (*http.Request, bool)
|
||||
}
|
||||
|
||||
// Authentication is a higher order authentication middleware.
|
||||
func Authentication(opts ...Option) func(next http.Handler) http.Handler {
|
||||
func Authentication(auths []Authenticator, opts ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(opts...)
|
||||
|
||||
configureSupportedChallenges(options)
|
||||
oidc := newOIDCAuth(options)
|
||||
basic := newBasicAuth(options)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if options.OIDCIss != "" && options.EnableBasicAuth {
|
||||
oidc(basic(next)).ServeHTTP(w, r)
|
||||
if isOIDCTokenAuth(r) || isUnprotectedPath(r) {
|
||||
// The authentication for this request is handled by the IdP.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if options.OIDCIss != "" && !options.EnableBasicAuth {
|
||||
oidc(next).ServeHTTP(w, r)
|
||||
for _, a := range auths {
|
||||
if req, ok := a.Authenticate(r); ok {
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !isPublicPath(r.URL.Path) {
|
||||
// Failed basic authentication attempts receive the Www-Authenticate header in the response
|
||||
var touch bool
|
||||
caser := cases.Title(language.Und)
|
||||
for k, v := range options.CredentialsByUserAgent {
|
||||
if strings.Contains(k, r.UserAgent()) {
|
||||
removeSuperfluousAuthenticate(w)
|
||||
w.Header().Add("Www-Authenticate", fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", caser.String(v), r.Host))
|
||||
touch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// if the request is not bound to any user agent, write all available challenges
|
||||
if !touch &&
|
||||
// This is a temporary hack... Before the authentication middleware rewrite all
|
||||
// unauthenticated requests were still handled. The reva http services then did add
|
||||
// the supported authentication headers to the response. Since we are not allowing the
|
||||
// requests to continue so far we have to do it here. But we shouldn't do it for the graph service.
|
||||
// That's the reason for this hard check here.
|
||||
!strings.HasPrefix(r.URL.Path, "/graph") {
|
||||
writeSupportedAuthenticateHeader(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
if options.OIDCIss == "" && options.EnableBasicAuth {
|
||||
basic(next).ServeHTTP(w, r)
|
||||
for _, s := range SupportedAuthStrategies {
|
||||
userAgentAuthenticateLockIn(w, r, options.CredentialsByUserAgent, s)
|
||||
}
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
// if the request is a PROPFIND return a WebDAV error code.
|
||||
// TODO: The proxy has to be smart enough to detect when a request is directed towards a webdav server
|
||||
// and react accordingly.
|
||||
if webdav.IsWebdavRequest(r) {
|
||||
b, err := webdav.Marshal(webdav.Exception{
|
||||
Code: webdav.SabredavPermissionDenied,
|
||||
Message: "Authentication error",
|
||||
})
|
||||
|
||||
webdav.HandleWebdavError(w, b, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// The token auth endpoint uses basic auth for clients, see https://openid.net/specs/openid-connect-basic-1_0.html#TokenRequest
|
||||
// > The Client MUST authenticate to the Token Endpoint using the HTTP Basic method, as described in 2.3.1 of OAuth 2.0.
|
||||
func isOIDCTokenAuth(req *http.Request) bool {
|
||||
return req.URL.Path == "/konnect/v1/token"
|
||||
}
|
||||
|
||||
func isUnprotectedPath(r *http.Request) bool {
|
||||
if _, ok := _unprotectedPaths[r.URL.Path]; ok {
|
||||
return true
|
||||
}
|
||||
for _, p := range _unprotectedPathPrefixes {
|
||||
if strings.HasPrefix(r.URL.Path, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isPublicPath(p string) bool {
|
||||
for _, pp := range _publicPaths {
|
||||
if strings.HasPrefix(p, pp) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// configureSupportedChallenges adds known authentication challenges to the current session.
|
||||
func configureSupportedChallenges(options Options) {
|
||||
if options.OIDCIss != "" {
|
||||
@@ -66,13 +178,23 @@ func configureSupportedChallenges(options Options) {
|
||||
}
|
||||
|
||||
func writeSupportedAuthenticateHeader(w http.ResponseWriter, r *http.Request) {
|
||||
for i := 0; i < len(SupportedAuthStrategies); i++ {
|
||||
w.Header().Add(WWWAuthenticate, fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", strings.Title(SupportedAuthStrategies[i]), r.Host))
|
||||
caser := cases.Title(language.Und)
|
||||
for _, s := range SupportedAuthStrategies {
|
||||
w.Header().Add(WwwAuthenticate, fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", caser.String(s), r.Host))
|
||||
}
|
||||
}
|
||||
|
||||
func removeSuperfluousAuthenticate(w http.ResponseWriter) {
|
||||
w.Header().Del(WWWAuthenticate)
|
||||
w.Header().Del(WwwAuthenticate)
|
||||
}
|
||||
|
||||
// userAgentLocker aids in dependency injection for helper methods. The set of fields is arbitrary and the only relation
|
||||
// they share is to fulfill their duty and lock a User-Agent to its correct challenge if configured.
|
||||
type userAgentLocker struct {
|
||||
w http.ResponseWriter
|
||||
r *http.Request
|
||||
locks map[string]string // locks represents a reva user-agent:challenge mapping.
|
||||
fallback string
|
||||
}
|
||||
|
||||
// userAgentAuthenticateLockIn sets Www-Authenticate according to configured user agents. This is useful for the case of
|
||||
@@ -86,49 +208,22 @@ func userAgentAuthenticateLockIn(w http.ResponseWriter, r *http.Request, locks m
|
||||
fallback: fallback,
|
||||
}
|
||||
|
||||
for i := 0; i < len(ProxyWwwAuthenticate); i++ {
|
||||
evalRequestURI(&u, i)
|
||||
for _, r := range ProxyWwwAuthenticate {
|
||||
evalRequestURI(u, r)
|
||||
}
|
||||
}
|
||||
|
||||
func evalRequestURI(l *userAgentLocker, i int) {
|
||||
r := regexp.MustCompile(ProxyWwwAuthenticate[i])
|
||||
if r.Match([]byte(l.r.RequestURI)) {
|
||||
for k, v := range l.locks {
|
||||
if strings.Contains(k, l.r.UserAgent()) {
|
||||
removeSuperfluousAuthenticate(l.w)
|
||||
l.w.Header().Add(WWWAuthenticate, fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", strings.Title(v), l.r.Host))
|
||||
return
|
||||
}
|
||||
func evalRequestURI(l userAgentLocker, r regexp.Regexp) {
|
||||
if !r.MatchString(l.r.RequestURI) {
|
||||
return
|
||||
}
|
||||
caser := cases.Title(language.Und)
|
||||
for k, v := range l.locks {
|
||||
if strings.Contains(k, l.r.UserAgent()) {
|
||||
removeSuperfluousAuthenticate(l.w)
|
||||
l.w.Header().Add(WwwAuthenticate, fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", caser.String(v), l.r.Host))
|
||||
return
|
||||
}
|
||||
l.w.Header().Add(WWWAuthenticate, fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", strings.Title(l.fallback), l.r.Host))
|
||||
}
|
||||
}
|
||||
|
||||
// newOIDCAuth returns a configured oidc middleware
|
||||
func newOIDCAuth(options Options) func(http.Handler) http.Handler {
|
||||
return OIDCAuth(
|
||||
Logger(options.Logger),
|
||||
OIDCProviderFunc(options.OIDCProviderFunc),
|
||||
HTTPClient(options.HTTPClient),
|
||||
OIDCIss(options.OIDCIss),
|
||||
TokenCacheSize(options.UserinfoCacheSize),
|
||||
TokenCacheTTL(options.UserinfoCacheTTL),
|
||||
CredentialsByUserAgent(options.CredentialsByUserAgent),
|
||||
AccessTokenVerifyMethod(options.AccessTokenVerifyMethod),
|
||||
JWKSOptions(options.JWKS),
|
||||
)
|
||||
}
|
||||
|
||||
// newBasicAuth returns a configured basic middleware
|
||||
func newBasicAuth(options Options) func(http.Handler) http.Handler {
|
||||
return BasicAuth(
|
||||
UserProvider(options.UserProvider),
|
||||
Logger(options.Logger),
|
||||
EnableBasicAuth(options.EnableBasicAuth),
|
||||
OIDCIss(options.OIDCIss),
|
||||
UserOIDCClaim(options.UserOIDCClaim),
|
||||
UserCS3Claim(options.UserCS3Claim),
|
||||
CredentialsByUserAgent(options.CredentialsByUserAgent),
|
||||
)
|
||||
l.w.Header().Add(WwwAuthenticate, fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", caser.String(l.fallback), l.r.Host))
|
||||
}
|
||||
|
||||
19
services/proxy/pkg/middleware/authentication_test.go
Normal file
19
services/proxy/pkg/middleware/authentication_test.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("authentication helpers", func() {
|
||||
DescribeTable("isPublicPath should recognize public paths",
|
||||
func(input string, expected bool) {
|
||||
isPublic := isPublicPath(input)
|
||||
Expect(isPublic).To(Equal(expected))
|
||||
},
|
||||
Entry("public files path", "/remote.php/dav/public-files/", true),
|
||||
Entry("public files path without remote.php", "/remote.php/dav/public-files/", true),
|
||||
Entry("token info path", "/remote.php/ocs/apps/files_sharing/api/v1/tokeninfo/unprotected", true),
|
||||
Entry("capabilities", "/ocs/v1.php/cloud/capabilities", true),
|
||||
)
|
||||
})
|
||||
@@ -1,145 +1,62 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/webdav"
|
||||
)
|
||||
|
||||
// BasicAuth provides a middleware to check if BasicAuth is provided
|
||||
func BasicAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
logger := options.Logger
|
||||
|
||||
if options.EnableBasicAuth {
|
||||
options.Logger.Warn().Msg("basic auth enabled, use only for testing or development")
|
||||
}
|
||||
|
||||
h := basicAuth{
|
||||
logger: logger,
|
||||
enabled: options.EnableBasicAuth,
|
||||
userProvider: options.UserProvider,
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, req *http.Request) {
|
||||
if h.isPublicLink(req) || !h.isBasicAuth(req) || h.isOIDCTokenAuth(req) {
|
||||
if !h.isPublicLink(req) {
|
||||
userAgentAuthenticateLockIn(w, req, options.CredentialsByUserAgent, "basic")
|
||||
}
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
removeSuperfluousAuthenticate(w)
|
||||
login, password, _ := req.BasicAuth()
|
||||
user, _, err := h.userProvider.Authenticate(req.Context(), login, password)
|
||||
|
||||
// touch is a user agent locking guard, when touched changes to true it indicates the User-Agent on the
|
||||
// request is configured to support only one challenge, it it remains untouched, there are no considera-
|
||||
// tions and we should write all available authentication challenges to the response.
|
||||
touch := false
|
||||
|
||||
if err != nil {
|
||||
for k, v := range options.CredentialsByUserAgent {
|
||||
if strings.Contains(k, req.UserAgent()) {
|
||||
removeSuperfluousAuthenticate(w)
|
||||
w.Header().Add("Www-Authenticate", fmt.Sprintf("%v realm=\"%s\", charset=\"UTF-8\"", strings.Title(v), req.Host))
|
||||
touch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// if the request is not bound to any user agent, write all available challenges
|
||||
if !touch {
|
||||
writeSupportedAuthenticateHeader(w, req)
|
||||
}
|
||||
|
||||
// if the request is a PROPFIND return a WebDAV error code.
|
||||
// TODO: The proxy has to be smart enough to detect when a request is directed towards a webdav server
|
||||
// and react accordingly.
|
||||
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
if webdav.IsWebdavRequest(req) {
|
||||
b, err := webdav.Marshal(webdav.Exception{
|
||||
Code: webdav.SabredavPermissionDenied,
|
||||
Message: "Authentication error",
|
||||
})
|
||||
|
||||
webdav.HandleWebdavError(w, b, err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// fake oidc claims
|
||||
claims := map[string]interface{}{
|
||||
oidc.Iss: user.Id.Idp,
|
||||
oidc.PreferredUsername: user.Username,
|
||||
oidc.Email: user.Mail,
|
||||
oidc.OwncloudUUID: user.Id.OpaqueId,
|
||||
}
|
||||
|
||||
if options.UserCS3Claim == "userid" {
|
||||
// set the custom user claim only if users will be looked up by the userid on the CS3api
|
||||
// OpaqueId contains the userid configured in STORAGE_LDAP_USER_SCHEMA_UID
|
||||
claims[options.UserOIDCClaim] = user.Id.OpaqueId
|
||||
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), claims)))
|
||||
},
|
||||
)
|
||||
}
|
||||
// BasicAuthenticator is the authenticator responsible for HTTP Basic authentication.
|
||||
type BasicAuthenticator struct {
|
||||
Logger log.Logger
|
||||
UserProvider backend.UserBackend
|
||||
UserCS3Claim string
|
||||
UserOIDCClaim string
|
||||
}
|
||||
|
||||
type basicAuth struct {
|
||||
logger log.Logger
|
||||
enabled bool
|
||||
userProvider backend.UserBackend
|
||||
}
|
||||
|
||||
func (m basicAuth) isPublicLink(req *http.Request) bool {
|
||||
login, _, ok := req.BasicAuth()
|
||||
|
||||
if !ok || login != "public" {
|
||||
return false
|
||||
// Authenticate implements the authenticator interface to authenticate requests via basic auth.
|
||||
func (m BasicAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
|
||||
if isPublicPath(r.URL.Path) {
|
||||
// The authentication of public path requests is handled by another authenticator.
|
||||
// Since we can't guarantee the order of execution of the authenticators, we better
|
||||
// implement an early return here for paths we can't authenticate in this authenticator.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
publicPaths := []string{
|
||||
"/dav/public-files/",
|
||||
"/remote.php/dav/public-files/",
|
||||
"/remote.php/ocs/apps/files_sharing/api/v1/tokeninfo/unprotected",
|
||||
"/ocs/v1.php/cloud/capabilities",
|
||||
"/data",
|
||||
}
|
||||
isPublic := false
|
||||
|
||||
for _, p := range publicPaths {
|
||||
if strings.HasPrefix(req.URL.Path, p) {
|
||||
isPublic = true
|
||||
break
|
||||
}
|
||||
login, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return isPublic
|
||||
}
|
||||
user, _, err := m.UserProvider.Authenticate(r.Context(), login, password)
|
||||
if err != nil {
|
||||
m.Logger.Error().
|
||||
Err(err).
|
||||
Str("authenticator", "basic").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("failed to authenticate request")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// The token auth endpoint uses basic auth for clients, see https://openid.net/specs/openid-connect-basic-1_0.html#TokenRequest
|
||||
// > The Client MUST authenticate to the Token Endpoint using the HTTP Basic method, as described in 2.3.1 of OAuth 2.0.
|
||||
func (m basicAuth) isOIDCTokenAuth(req *http.Request) bool {
|
||||
return req.URL.Path == "/konnect/v1/token"
|
||||
}
|
||||
// fake oidc claims
|
||||
claims := map[string]interface{}{
|
||||
oidc.Iss: user.Id.Idp,
|
||||
oidc.PreferredUsername: user.Username,
|
||||
oidc.Email: user.Mail,
|
||||
oidc.OwncloudUUID: user.Id.OpaqueId,
|
||||
}
|
||||
|
||||
func (m basicAuth) isBasicAuth(req *http.Request) bool {
|
||||
_, _, ok := req.BasicAuth()
|
||||
return m.enabled && ok
|
||||
if m.UserCS3Claim == "userid" {
|
||||
// set the custom user claim only if users will be looked up by the userid on the CS3api
|
||||
// OpaqueId contains the userid configured in STORAGE_LDAP_USER_SCHEMA_UID
|
||||
claims[m.UserOIDCClaim] = user.Id.OpaqueId
|
||||
|
||||
}
|
||||
m.Logger.Debug().
|
||||
Str("authenticator", "basic").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("successfully authenticated request")
|
||||
return r.WithContext(oidc.NewContext(r.Context(), claims)), true
|
||||
}
|
||||
|
||||
@@ -1,40 +1,68 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
userv1beta1 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend/test"
|
||||
)
|
||||
|
||||
/**/
|
||||
|
||||
func TestBasicAuth__isPublicLink(t *testing.T) {
|
||||
tests := []struct {
|
||||
url string
|
||||
username string
|
||||
expected bool
|
||||
}{
|
||||
{url: "/remote.php/dav/public-files/", username: "", expected: false},
|
||||
{url: "/remote.php/dav/public-files/", username: "abc", expected: false},
|
||||
{url: "/remote.php/dav/public-files/", username: "private", expected: false},
|
||||
{url: "/remote.php/dav/public-files/", username: "public", expected: true},
|
||||
{url: "/ocs/v1.php/cloud/capabilities", username: "", expected: false},
|
||||
{url: "/ocs/v1.php/cloud/capabilities", username: "abc", expected: false},
|
||||
{url: "/ocs/v1.php/cloud/capabilities", username: "private", expected: false},
|
||||
{url: "/ocs/v1.php/cloud/capabilities", username: "public", expected: true},
|
||||
{url: "/ocs/v1.php/cloud/users/admin", username: "public", expected: false},
|
||||
}
|
||||
ba := basicAuth{}
|
||||
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest("", tt.url, nil)
|
||||
|
||||
if tt.username != "" {
|
||||
req.SetBasicAuth(tt.username, "")
|
||||
var _ = Describe("Authenticating requests", Label("BasicAuthenticator"), func() {
|
||||
var authenticator Authenticator
|
||||
BeforeEach(func() {
|
||||
authenticator = BasicAuthenticator{
|
||||
Logger: log.NewLogger(),
|
||||
UserProvider: &test.UserBackendMock{
|
||||
AuthenticateFunc: func(ctx context.Context, username, password string) (*userv1beta1.User, string, error) {
|
||||
var user *userv1beta1.User
|
||||
if username == "testuser" && password == "testpassword" {
|
||||
user = &userv1beta1.User{
|
||||
Id: &userv1beta1.UserId{
|
||||
Idp: "IdpId",
|
||||
OpaqueId: "OpaqueId",
|
||||
},
|
||||
Username: "testuser",
|
||||
Mail: "testuser@example.com",
|
||||
}
|
||||
return user, "", nil
|
||||
}
|
||||
return nil, "", backend.ErrAccountNotFound
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
result := ba.isPublicLink(req)
|
||||
if result != tt.expected {
|
||||
t.Errorf("with %s expected %t got %t", tt.url, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
When("the request contains correct data", func() {
|
||||
It("should successfully authenticate", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/example/path", http.NoBody)
|
||||
req.SetBasicAuth("testuser", "testpassword")
|
||||
|
||||
req2, valid := authenticator.Authenticate(req)
|
||||
|
||||
Expect(valid).To(Equal(true))
|
||||
Expect(req2).ToNot(BeNil())
|
||||
})
|
||||
It("adds claims to the request context", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/example/path", http.NoBody)
|
||||
req.SetBasicAuth("testuser", "testpassword")
|
||||
|
||||
req2, valid := authenticator.Authenticate(req)
|
||||
Expect(valid).To(Equal(true))
|
||||
|
||||
claims := oidc.FromContext(req2.Context())
|
||||
Expect(claims).ToNot(BeNil())
|
||||
Expect(claims[oidc.Iss]).To(Equal("IdpId"))
|
||||
Expect(claims[oidc.PreferredUsername]).To(Equal("testuser"))
|
||||
Expect(claims[oidc.Email]).To(Equal("testuser@example.com"))
|
||||
Expect(claims[oidc.OwncloudUUID]).To(Equal("OpaqueId"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
13
services/proxy/pkg/middleware/middleware_suite_test.go
Normal file
13
services/proxy/pkg/middleware/middleware_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Middleware Suite")
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -17,90 +16,63 @@ import (
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
|
||||
osync "github.com/owncloud/ocis/v2/ocis-pkg/sync"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
_headerAuthorization = "Authorization"
|
||||
_bearerPrefix = "Bearer "
|
||||
)
|
||||
|
||||
// OIDCProvider used to mock the oidc provider during tests
|
||||
type OIDCProvider interface {
|
||||
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*gOidc.UserInfo, error)
|
||||
}
|
||||
|
||||
// OIDCAuth provides a middleware to check access secured by a static token.
|
||||
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
tokenCache := osync.NewCache(options.UserinfoCacheSize)
|
||||
|
||||
h := oidcAuth{
|
||||
logger: options.Logger,
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
// NewOIDCAuthenticator returns a ready to use authenticator which can handle OIDC authentication.
|
||||
func NewOIDCAuthenticator(logger log.Logger, tokenCacheTTL int, oidcHTTPClient *http.Client, oidcIss string, providerFunc func() (OIDCProvider, error),
|
||||
jwksOptions config.JWKS, accessTokenVerifyMethod string) OIDCAuthenticator {
|
||||
tokenCache := osync.NewCache(tokenCacheTTL)
|
||||
return OIDCAuthenticator{
|
||||
Logger: logger,
|
||||
tokenCache: &tokenCache,
|
||||
tokenCacheTTL: options.UserinfoCacheTTL,
|
||||
accessTokenVerifyMethod: options.AccessTokenVerifyMethod,
|
||||
jwksOptions: options.JWKS,
|
||||
jwksLock: &sync.Mutex{},
|
||||
TokenCacheTTL: time.Duration(tokenCacheTTL),
|
||||
HTTPClient: oidcHTTPClient,
|
||||
OIDCIss: oidcIss,
|
||||
ProviderFunc: providerFunc,
|
||||
JWKSOptions: jwksOptions,
|
||||
AccessTokenVerifyMethod: accessTokenVerifyMethod,
|
||||
providerLock: &sync.Mutex{},
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
// there is no bearer token on the request,
|
||||
if !h.shouldServe(req) {
|
||||
// oidc supported but token not present, add header and handover to the next middleware.
|
||||
userAgentAuthenticateLockIn(w, req, options.CredentialsByUserAgent, "bearer")
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
if h.getProvider() == nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Force init of jwks keyfunc if needed (contacts the .well-known and jwks endpoints on first call)
|
||||
if h.accessTokenVerifyMethod == config.AccessTokenVerificationJWT && h.getKeyfunc() == nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
|
||||
claims, status := h.getClaims(token, req)
|
||||
if status != 0 {
|
||||
w.WriteHeader(status)
|
||||
return
|
||||
}
|
||||
|
||||
// inject claims to the request context for the account_resolver middleware.
|
||||
next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), claims)))
|
||||
})
|
||||
jwksLock: &sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
type oidcAuth struct {
|
||||
logger log.Logger
|
||||
provider OIDCProvider
|
||||
providerLock *sync.Mutex
|
||||
jwksOptions config.JWKS
|
||||
jwks *keyfunc.JWKS
|
||||
jwksLock *sync.Mutex
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
// OIDCAuthenticator is an authenticator responsible for OIDC authentication.
|
||||
type OIDCAuthenticator struct {
|
||||
Logger log.Logger
|
||||
HTTPClient *http.Client
|
||||
OIDCIss string
|
||||
tokenCache *osync.Cache
|
||||
tokenCacheTTL time.Duration
|
||||
accessTokenVerifyMethod string
|
||||
TokenCacheTTL time.Duration
|
||||
ProviderFunc func() (OIDCProvider, error)
|
||||
AccessTokenVerifyMethod string
|
||||
JWKSOptions config.JWKS
|
||||
|
||||
providerLock *sync.Mutex
|
||||
provider OIDCProvider
|
||||
|
||||
jwksLock *sync.Mutex
|
||||
JWKS *keyfunc.JWKS
|
||||
}
|
||||
|
||||
func (m oidcAuth) getClaims(token string, req *http.Request) (claims map[string]interface{}, status int) {
|
||||
func (m OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
|
||||
var claims map[string]interface{}
|
||||
hit := m.tokenCache.Load(token)
|
||||
if hit == nil {
|
||||
aClaims, err := m.verifyAccessToken(token)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("Failed to verify access token")
|
||||
status = http.StatusUnauthorized
|
||||
return
|
||||
return nil, errors.Wrap(err, "failed to verify access token")
|
||||
}
|
||||
|
||||
oauth2Token := &oauth2.Token{
|
||||
@@ -108,52 +80,46 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims map[string]
|
||||
}
|
||||
|
||||
userInfo, err := m.getProvider().UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("Failed to get userinfo")
|
||||
status = http.StatusUnauthorized
|
||||
return
|
||||
return nil, errors.Wrap(err, "failed to get userinfo")
|
||||
}
|
||||
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
status = http.StatusInternalServerError
|
||||
return
|
||||
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
|
||||
}
|
||||
|
||||
expiration := m.extractExpiration(aClaims)
|
||||
m.tokenCache.Store(token, claims, expiration)
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
|
||||
return
|
||||
m.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if claims, ok = hit.V.(map[string]interface{}); !ok {
|
||||
status = http.StatusInternalServerError
|
||||
return
|
||||
return nil, errors.New("failed to cast claims from the cache")
|
||||
}
|
||||
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
||||
return
|
||||
m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (m oidcAuth) verifyAccessToken(token string) (jwt.RegisteredClaims, error) {
|
||||
switch m.accessTokenVerifyMethod {
|
||||
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, error) {
|
||||
switch m.AccessTokenVerifyMethod {
|
||||
case config.AccessTokenVerificationJWT:
|
||||
return m.verifyAccessTokenJWT(token)
|
||||
case config.AccessTokenVerificationNone:
|
||||
m.logger.Debug().Msg("Access Token verification disabled")
|
||||
m.Logger.Debug().Msg("Access Token verification disabled")
|
||||
return jwt.RegisteredClaims{}, nil
|
||||
default:
|
||||
m.logger.Error().Str("access_token_verify_method", m.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
||||
m.Logger.Error().Str("access_token_verify_method", m.AccessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
||||
return jwt.RegisteredClaims{}, errors.New("Unknown Access Token Verification method")
|
||||
}
|
||||
}
|
||||
|
||||
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
|
||||
func (m oidcAuth) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, error) {
|
||||
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, error) {
|
||||
var claims jwt.RegisteredClaims
|
||||
jwks := m.getKeyfunc()
|
||||
if jwks == nil {
|
||||
@@ -161,13 +127,13 @@ func (m oidcAuth) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, erro
|
||||
}
|
||||
|
||||
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
|
||||
m.logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
||||
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
||||
if err != nil {
|
||||
m.logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
||||
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
||||
return claims, err
|
||||
}
|
||||
|
||||
if !claims.VerifyIssuer(m.oidcIss, true) {
|
||||
if !claims.VerifyIssuer(m.OIDCIss, true) {
|
||||
vErr := jwt.ValidationError{}
|
||||
vErr.Inner = jwt.ErrTokenInvalidIssuer
|
||||
vErr.Errors |= jwt.ValidationErrorIssuer
|
||||
@@ -180,89 +146,80 @@ func (m oidcAuth) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, erro
|
||||
// extractExpiration tries to extract the expriration time from the access token
|
||||
// If the access token does not have an exp claim it will fallback to the configured
|
||||
// default expiration
|
||||
func (m oidcAuth) extractExpiration(aClaims jwt.RegisteredClaims) time.Time {
|
||||
defaultExpiration := time.Now().Add(m.tokenCacheTTL)
|
||||
func (m OIDCAuthenticator) extractExpiration(aClaims jwt.RegisteredClaims) time.Time {
|
||||
defaultExpiration := time.Now().Add(m.TokenCacheTTL)
|
||||
if aClaims.ExpiresAt != nil {
|
||||
m.logger.Debug().Str("exp", aClaims.ExpiresAt.String()).Msg("Expiration Time from access_token")
|
||||
m.Logger.Debug().Str("exp", aClaims.ExpiresAt.String()).Msg("Expiration Time from access_token")
|
||||
return aClaims.ExpiresAt.Time
|
||||
}
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
func (m oidcAuth) shouldServe(req *http.Request) bool {
|
||||
header := req.Header.Get("Authorization")
|
||||
|
||||
if m.oidcIss == "" {
|
||||
func (m OIDCAuthenticator) shouldServe(req *http.Request) bool {
|
||||
if m.OIDCIss == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// todo: looks dirty, check later
|
||||
// TODO: make a PR to coreos/go-oidc for exposing userinfo endpoint on provider, see https://github.com/coreos/go-oidc/issues/248
|
||||
for _, ignoringPath := range []string{"/konnect/v1/userinfo", "/status.php"} {
|
||||
if req.URL.Path == ignoringPath {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return strings.HasPrefix(header, "Bearer ")
|
||||
header := req.Header.Get(_headerAuthorization)
|
||||
return strings.HasPrefix(header, _bearerPrefix)
|
||||
}
|
||||
|
||||
type jwksJSON struct {
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
}
|
||||
|
||||
func (m *oidcAuth) getKeyfunc() *keyfunc.JWKS {
|
||||
func (m OIDCAuthenticator) getKeyfunc() *keyfunc.JWKS {
|
||||
m.jwksLock.Lock()
|
||||
defer m.jwksLock.Unlock()
|
||||
if m.jwks == nil {
|
||||
wellKnown := strings.TrimSuffix(m.oidcIss, "/") + "/.well-known/openid-configuration"
|
||||
if m.JWKS == nil {
|
||||
wellKnown := strings.TrimSuffix(m.OIDCIss, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
resp, err := m.httpClient.Get(wellKnown)
|
||||
resp, err := m.HTTPClient.Get(wellKnown)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("Failed to set request for .well-known/openid-configuration")
|
||||
m.Logger.Error().Err(err).Msg("Failed to set request for .well-known/openid-configuration")
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("unable to read discovery response body")
|
||||
m.Logger.Error().Err(err).Msg("unable to read discovery response body")
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
m.logger.Error().Str("status", resp.Status).Str("body", string(body)).Msg("error requesting openid-configuration")
|
||||
m.Logger.Error().Str("status", resp.Status).Str("body", string(body)).Msg("error requesting openid-configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
var j jwksJSON
|
||||
err = json.Unmarshal(body, &j)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
|
||||
m.Logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
|
||||
return nil
|
||||
}
|
||||
m.logger.Debug().Str("jwks", j.JWKSURL).Msg("discovered jwks endpoint")
|
||||
m.Logger.Debug().Str("jwks", j.JWKSURL).Msg("discovered jwks endpoint")
|
||||
options := keyfunc.Options{
|
||||
Client: m.httpClient,
|
||||
Client: m.HTTPClient,
|
||||
RefreshErrorHandler: func(err error) {
|
||||
m.logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
|
||||
m.Logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
|
||||
},
|
||||
RefreshInterval: time.Minute * time.Duration(m.jwksOptions.RefreshInterval),
|
||||
RefreshRateLimit: time.Second * time.Duration(m.jwksOptions.RefreshRateLimit),
|
||||
RefreshTimeout: time.Second * time.Duration(m.jwksOptions.RefreshTimeout),
|
||||
RefreshUnknownKID: m.jwksOptions.RefreshUnknownKID,
|
||||
RefreshInterval: time.Minute * time.Duration(m.JWKSOptions.RefreshInterval),
|
||||
RefreshRateLimit: time.Second * time.Duration(m.JWKSOptions.RefreshRateLimit),
|
||||
RefreshTimeout: time.Second * time.Duration(m.JWKSOptions.RefreshTimeout),
|
||||
RefreshUnknownKID: m.JWKSOptions.RefreshUnknownKID,
|
||||
}
|
||||
m.jwks, err = keyfunc.Get(j.JWKSURL, options)
|
||||
m.JWKS, err = keyfunc.Get(j.JWKSURL, options)
|
||||
if err != nil {
|
||||
m.jwks = nil
|
||||
m.logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
|
||||
m.JWKS = nil
|
||||
m.Logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return m.jwks
|
||||
return m.JWKS
|
||||
}
|
||||
|
||||
func (m *oidcAuth) getProvider() OIDCProvider {
|
||||
func (m OIDCAuthenticator) getProvider() OIDCProvider {
|
||||
m.providerLock.Lock()
|
||||
defer m.providerLock.Unlock()
|
||||
if m.provider == nil {
|
||||
@@ -271,9 +228,9 @@ func (m *oidcAuth) getProvider() OIDCProvider {
|
||||
// provider needs to be cached as when it is created
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
provider, err := m.providerFunc()
|
||||
provider, err := m.ProviderFunc()
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
|
||||
m.Logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -281,3 +238,39 @@ func (m *oidcAuth) getProvider() OIDCProvider {
|
||||
}
|
||||
return m.provider
|
||||
}
|
||||
|
||||
// Authenticate implements the authenticator interface to authenticate requests via oidc auth.
|
||||
func (m OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
|
||||
// there is no bearer token on the request,
|
||||
if !m.shouldServe(r) || isPublicPath(r.URL.Path) {
|
||||
// The authentication of public path requests is handled by another authenticator.
|
||||
// Since we can't guarantee the order of execution of the authenticators, we better
|
||||
// implement an early return here for paths we can't authenticate in this authenticator.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if m.getProvider() == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Force init of jwks keyfunc if needed (contacts the .well-known and jwks endpoints on first call)
|
||||
if m.AccessTokenVerifyMethod == config.AccessTokenVerificationJWT && m.getKeyfunc() == nil {
|
||||
return nil, false
|
||||
}
|
||||
token := strings.TrimPrefix(r.Header.Get(_headerAuthorization), _bearerPrefix)
|
||||
|
||||
claims, err := m.getClaims(token, r)
|
||||
if err != nil {
|
||||
m.Logger.Error().
|
||||
Err(err).
|
||||
Str("authenticator", "oidc").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("failed to authenticate the request")
|
||||
return nil, false
|
||||
}
|
||||
m.Logger.Debug().
|
||||
Str("authenticator", "oidc").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("successfully authenticated request")
|
||||
return r.WithContext(oidc.NewContext(r.Context(), claims)), true
|
||||
}
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOIDCAuthMiddleware(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
m := OIDCAuth(
|
||||
Logger(log.NewLogger()),
|
||||
OIDCProviderFunc(func() (OIDCProvider, error) {
|
||||
return mockOP(false), nil
|
||||
}),
|
||||
OIDCIss("https://localhost:9200"),
|
||||
AccessTokenVerifyMethod(config.AccessTokenVerificationNone),
|
||||
)(next)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "https://idp.example.com", nil)
|
||||
r.Header.Set("Authorization", "Bearer sometoken")
|
||||
w := httptest.NewRecorder()
|
||||
m.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected an internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
type mockOIDCProvider struct {
|
||||
UserInfoFunc func(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error)
|
||||
}
|
||||
|
||||
// UserInfo will panic if the function has been called, but not mocked
|
||||
func (m mockOIDCProvider) UserInfo(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
if m.UserInfoFunc != nil {
|
||||
return m.UserInfoFunc(ctx, ts)
|
||||
}
|
||||
|
||||
panic("UserInfo was called in test but not mocked")
|
||||
}
|
||||
|
||||
func mockOP(retErr bool) OIDCProvider {
|
||||
if retErr {
|
||||
return &mockOIDCProvider{
|
||||
UserInfoFunc: func(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
return nil, fmt.Errorf("error returned by mockOIDCProvider UserInfo")
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
return &mockOIDCProvider{
|
||||
UserInfoFunc: func(ctx context.Context, ts oauth2.TokenSource) (*oidc.UserInfo, error) {
|
||||
ui := &oidc.UserInfo{
|
||||
// claims: private ...
|
||||
}
|
||||
return ui, nil
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
@@ -5,68 +5,83 @@ import (
|
||||
"strings"
|
||||
|
||||
gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
||||
)
|
||||
|
||||
const (
|
||||
headerRevaAccessToken = "x-access-token"
|
||||
_headerRevaAccessToken = "x-access-token"
|
||||
headerShareToken = "public-token"
|
||||
basicAuthPasswordPrefix = "password|"
|
||||
authenticationType = "publicshares"
|
||||
|
||||
_paramSignature = "signature"
|
||||
_paramExpiration = "expiration"
|
||||
)
|
||||
|
||||
// PublicShareAuth ...
|
||||
func PublicShareAuth(opts ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(opts...)
|
||||
logger := options.Logger
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
shareToken := r.Header.Get(headerShareToken)
|
||||
if shareToken == "" {
|
||||
shareToken = r.URL.Query().Get(headerShareToken)
|
||||
}
|
||||
|
||||
// Currently we only want to authenticate app open request coming from public shares.
|
||||
if shareToken == "" {
|
||||
// Don't authenticate
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var sharePassword string
|
||||
if signature := r.URL.Query().Get("signature"); signature != "" {
|
||||
expiration := r.URL.Query().Get("expiration")
|
||||
if expiration == "" {
|
||||
logger.Warn().Str("signature", signature).Msg("cannot do signature auth without the expiration")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
sharePassword = strings.Join([]string{"signature", signature, expiration}, "|")
|
||||
} else {
|
||||
// We can ignore the username since it is always set to "public" in public shares.
|
||||
_, password, ok := r.BasicAuth()
|
||||
|
||||
sharePassword = basicAuthPasswordPrefix
|
||||
if ok {
|
||||
sharePassword += password
|
||||
}
|
||||
}
|
||||
|
||||
authResp, err := options.RevaGatewayClient.Authenticate(r.Context(), &gateway.AuthenticateRequest{
|
||||
Type: authenticationType,
|
||||
ClientId: shareToken,
|
||||
ClientSecret: sharePassword,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Debug().Err(err).Str("public_share_token", shareToken).Msg("could not authenticate public share")
|
||||
// try another middleware
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
r.Header.Add(headerRevaAccessToken, authResp.Token)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
// PublicShareAuthenticator is the authenticator which can authenticate public share requests.
|
||||
// It will add the share owner into the request context.
|
||||
type PublicShareAuthenticator struct {
|
||||
Logger log.Logger
|
||||
RevaGatewayClient gateway.GatewayAPIClient
|
||||
}
|
||||
|
||||
// Authenticate implements the authenticator interface to authenticate requests via public share auth.
|
||||
func (a PublicShareAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
|
||||
if !isPublicPath(r.URL.Path) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
query := r.URL.Query()
|
||||
shareToken := r.Header.Get(headerShareToken)
|
||||
if shareToken == "" {
|
||||
shareToken = query.Get(headerShareToken)
|
||||
}
|
||||
|
||||
if shareToken == "" {
|
||||
// If the share token is not set then we don't need to inject the user to
|
||||
// the request context so we can just continue with the request.
|
||||
return r, true
|
||||
}
|
||||
|
||||
var sharePassword string
|
||||
if signature := query.Get(_paramSignature); signature != "" {
|
||||
expiration := query.Get(_paramExpiration)
|
||||
if expiration == "" {
|
||||
a.Logger.Warn().Str("signature", signature).Msg("cannot do signature auth without the expiration")
|
||||
return nil, false
|
||||
}
|
||||
sharePassword = strings.Join([]string{"signature", signature, expiration}, "|")
|
||||
} else {
|
||||
// We can ignore the username since it is always set to "public" in public shares.
|
||||
_, password, ok := r.BasicAuth()
|
||||
|
||||
sharePassword = basicAuthPasswordPrefix
|
||||
if ok {
|
||||
sharePassword += password
|
||||
}
|
||||
}
|
||||
|
||||
authResp, err := a.RevaGatewayClient.Authenticate(r.Context(), &gateway.AuthenticateRequest{
|
||||
Type: authenticationType,
|
||||
ClientId: shareToken,
|
||||
ClientSecret: sharePassword,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
a.Logger.Error().
|
||||
Err(err).
|
||||
Str("authenticator", "public_share").
|
||||
Str("public_share_token", shareToken).
|
||||
Str("path", r.URL.Path).
|
||||
Msg("failed to authenticate request")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
r.Header.Add(_headerRevaAccessToken, authResp.Token)
|
||||
|
||||
a.Logger.Debug().
|
||||
Str("authenticator", "public_share").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("successfully authenticated request")
|
||||
return r, true
|
||||
}
|
||||
|
||||
78
services/proxy/pkg/middleware/public_share_auth_test.go
Normal file
78
services/proxy/pkg/middleware/public_share_auth_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
gatewayv1beta1 "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
|
||||
rpcv1beta1 "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var _ = Describe("Authenticating requests", Label("PublicShareAuthenticator"), func() {
|
||||
var authenticator Authenticator
|
||||
BeforeEach(func() {
|
||||
authenticator = PublicShareAuthenticator{
|
||||
Logger: log.NewLogger(),
|
||||
RevaGatewayClient: mockGatewayClient{
|
||||
AuthenticateFunc: func(authType, clientID, clientSecret string) (string, rpcv1beta1.Code) {
|
||||
if authType != "publicshares" {
|
||||
return "", rpcv1beta1.Code_CODE_NOT_FOUND
|
||||
}
|
||||
|
||||
if clientID == "sharetoken" && (clientSecret == "password|examples3cr3t" || clientSecret == "signature|examplesignature|exampleexpiration") {
|
||||
return "exampletoken", rpcv1beta1.Code_CODE_OK
|
||||
}
|
||||
|
||||
return "", rpcv1beta1.Code_CODE_NOT_FOUND
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
When("the request contains correct data", func() {
|
||||
Context("using password authentication", func() {
|
||||
It("should successfully authenticate", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/dav/public-files/?public-token=sharetoken", http.NoBody)
|
||||
req.SetBasicAuth("public", "examples3cr3t")
|
||||
|
||||
req2, valid := authenticator.Authenticate(req)
|
||||
|
||||
Expect(valid).To(Equal(true))
|
||||
Expect(req2).ToNot(BeNil())
|
||||
|
||||
h := req2.Header
|
||||
Expect(h.Get(_headerRevaAccessToken)).To(Equal("exampletoken"))
|
||||
})
|
||||
})
|
||||
Context("using signature authentication", func() {
|
||||
It("should successfully authenticate", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/dav/public-files/?public-token=sharetoken&signature=examplesignature&expiration=exampleexpiration", http.NoBody)
|
||||
|
||||
req2, valid := authenticator.Authenticate(req)
|
||||
|
||||
Expect(valid).To(Equal(true))
|
||||
Expect(req2).ToNot(BeNil())
|
||||
|
||||
h := req2.Header
|
||||
Expect(h.Get(_headerRevaAccessToken)).To(Equal("exampletoken"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
type mockGatewayClient struct {
|
||||
gatewayv1beta1.GatewayAPIClient
|
||||
AuthenticateFunc func(authType, clientID, clientSecret string) (string, rpcv1beta1.Code)
|
||||
}
|
||||
|
||||
func (c mockGatewayClient) Authenticate(ctx context.Context, in *gatewayv1beta1.AuthenticateRequest, opts ...grpc.CallOption) (*gatewayv1beta1.AuthenticateResponse, error) {
|
||||
token, code := c.AuthenticateFunc(in.GetType(), in.GetClientId(), in.GetClientSecret())
|
||||
return &gatewayv1beta1.AuthenticateResponse{
|
||||
Status: &rpcv1beta1.Status{Code: code},
|
||||
Token: token,
|
||||
}, nil
|
||||
}
|
||||
@@ -20,61 +20,40 @@ import (
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
// SignedURLAuth provides a middleware to check access secured by a signed URL.
|
||||
func SignedURLAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
const (
|
||||
_paramOCSignature = "OC-Signature"
|
||||
_paramOCCredential = "OC-Credential"
|
||||
_paramOCDate = "OC-Date"
|
||||
_paramOCExpires = "OC-Expires"
|
||||
_paramOCVerb = "OC-Verb"
|
||||
)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return &signedURLAuth{
|
||||
next: next,
|
||||
logger: options.Logger,
|
||||
preSignedURLConfig: options.PreSignedURLConfig,
|
||||
store: options.Store,
|
||||
userProvider: options.UserProvider,
|
||||
}
|
||||
var (
|
||||
_requiredParams = [...]string{
|
||||
_paramOCSignature,
|
||||
_paramOCCredential,
|
||||
_paramOCDate,
|
||||
_paramOCExpires,
|
||||
_paramOCVerb,
|
||||
}
|
||||
)
|
||||
|
||||
// SignedURLAuthenticator is the authenticator responsible for authenticating signed URL requests.
|
||||
type SignedURLAuthenticator struct {
|
||||
Logger log.Logger
|
||||
PreSignedURLConfig config.PreSignedURL
|
||||
UserProvider backend.UserBackend
|
||||
Store storesvc.StoreService
|
||||
}
|
||||
|
||||
type signedURLAuth struct {
|
||||
next http.Handler
|
||||
logger log.Logger
|
||||
preSignedURLConfig config.PreSignedURL
|
||||
userProvider backend.UserBackend
|
||||
store storesvc.StoreService
|
||||
}
|
||||
|
||||
func (m signedURLAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if !m.shouldServe(req) {
|
||||
m.next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
user, _, err := m.userProvider.GetUserByClaims(req.Context(), "username", req.URL.Query().Get("OC-Credential"), true)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("Could not get user by claim")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
ctx := revactx.ContextSetUser(req.Context(), user)
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
if err := m.validate(req); err != nil {
|
||||
http.Error(w, "Invalid url signature", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
m.next.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func (m signedURLAuth) shouldServe(req *http.Request) bool {
|
||||
if !m.preSignedURLConfig.Enabled {
|
||||
func (m SignedURLAuthenticator) shouldServe(req *http.Request) bool {
|
||||
if !m.PreSignedURLConfig.Enabled {
|
||||
return false
|
||||
}
|
||||
return req.URL.Query().Get("OC-Signature") != ""
|
||||
return req.URL.Query().Get(_paramOCSignature) != ""
|
||||
}
|
||||
|
||||
func (m signedURLAuth) validate(req *http.Request) (err error) {
|
||||
func (m SignedURLAuthenticator) validate(req *http.Request) (err error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if ok, err := m.allRequiredParametersArePresent(query); !ok {
|
||||
@@ -100,20 +79,14 @@ func (m signedURLAuth) validate(req *http.Request) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m signedURLAuth) allRequiredParametersArePresent(query url.Values) (ok bool, err error) {
|
||||
func (m SignedURLAuthenticator) allRequiredParametersArePresent(query url.Values) (ok bool, err error) {
|
||||
// check if required query parameters exist in given request query parameters
|
||||
// OC-Signature - the computed signature - server will verify the request upon this REQUIRED
|
||||
// OC-Credential - defines the user scope (shall we use the owncloud user id here - this might leak internal data ....) REQUIRED
|
||||
// OC-Date - defined the date the url was signed (ISO 8601 UTC) REQUIRED
|
||||
// OC-Expires - defines the expiry interval in seconds (between 1 and 604800 = 7 days) REQUIRED
|
||||
// TODO OC-Verb - defines for which http verb the request is valid - defaults to GET OPTIONAL
|
||||
for _, p := range []string{
|
||||
"OC-Signature",
|
||||
"OC-Credential",
|
||||
"OC-Date",
|
||||
"OC-Expires",
|
||||
"OC-Verb",
|
||||
} {
|
||||
for _, p := range _requiredParams {
|
||||
if query.Get(p) == "" {
|
||||
return false, fmt.Errorf("required %s parameter not found", p)
|
||||
}
|
||||
@@ -122,19 +95,19 @@ func (m signedURLAuth) allRequiredParametersArePresent(query url.Values) (ok boo
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m signedURLAuth) requestMethodMatches(meth string, query url.Values) (ok bool, err error) {
|
||||
func (m SignedURLAuthenticator) requestMethodMatches(meth string, query url.Values) (ok bool, err error) {
|
||||
// check if given url query parameter OC-Verb matches given request method
|
||||
if !strings.EqualFold(meth, query.Get("OC-Verb")) {
|
||||
if !strings.EqualFold(meth, query.Get(_paramOCVerb)) {
|
||||
return false, errors.New("required OC-Verb parameter did not match request method")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m signedURLAuth) requestMethodIsAllowed(meth string) (ok bool, err error) {
|
||||
func (m SignedURLAuthenticator) requestMethodIsAllowed(meth string) (ok bool, err error) {
|
||||
// check if given request method is allowed
|
||||
methodIsAllowed := false
|
||||
for _, am := range m.preSignedURLConfig.AllowedHTTPMethods {
|
||||
for _, am := range m.PreSignedURLConfig.AllowedHTTPMethods {
|
||||
if strings.EqualFold(meth, am) {
|
||||
methodIsAllowed = true
|
||||
break
|
||||
@@ -147,14 +120,15 @@ func (m signedURLAuth) requestMethodIsAllowed(meth string) (ok bool, err error)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
func (m signedURLAuth) urlIsExpired(query url.Values, now func() time.Time) (expired bool, err error) {
|
||||
|
||||
func (m SignedURLAuthenticator) urlIsExpired(query url.Values, now func() time.Time) (expired bool, err error) {
|
||||
// check if url is expired by checking if given date (OC-Date) + expires in seconds (OC-Expires) is after now
|
||||
validFrom, err := time.Parse(time.RFC3339, query.Get("OC-Date"))
|
||||
validFrom, err := time.Parse(time.RFC3339, query.Get(_paramOCDate))
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
requestExpiry, err := time.ParseDuration(query.Get("OC-Expires") + "s")
|
||||
requestExpiry, err := time.ParseDuration(query.Get(_paramOCExpires) + "s")
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -164,20 +138,20 @@ func (m signedURLAuth) urlIsExpired(query url.Values, now func() time.Time) (exp
|
||||
return !(now().After(validFrom) && now().Before(validTo)), nil
|
||||
}
|
||||
|
||||
func (m signedURLAuth) signatureIsValid(req *http.Request) (ok bool, err error) {
|
||||
func (m SignedURLAuthenticator) signatureIsValid(req *http.Request) (ok bool, err error) {
|
||||
u := revactx.ContextMustGetUser(req.Context())
|
||||
signingKey, err := m.getSigningKey(req.Context(), u.Id.OpaqueId)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("could not retrieve signing key")
|
||||
m.Logger.Error().Err(err).Msg("could not retrieve signing key")
|
||||
return false, err
|
||||
}
|
||||
if len(signingKey) == 0 {
|
||||
m.logger.Error().Err(err).Msg("signing key empty")
|
||||
m.Logger.Error().Err(err).Msg("signing key empty")
|
||||
return false, err
|
||||
}
|
||||
q := req.URL.Query()
|
||||
signature := q.Get("OC-Signature")
|
||||
q.Del("OC-Signature")
|
||||
signature := q.Get(_paramOCSignature)
|
||||
q.Del(_paramOCSignature)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
url := req.URL.String()
|
||||
if !req.URL.IsAbs() {
|
||||
@@ -187,7 +161,7 @@ func (m signedURLAuth) signatureIsValid(req *http.Request) (ok bool, err error)
|
||||
return m.createSignature(url, signingKey) == signature, nil
|
||||
}
|
||||
|
||||
func (m signedURLAuth) createSignature(url string, signingKey []byte) string {
|
||||
func (m SignedURLAuthenticator) createSignature(url string, signingKey []byte) string {
|
||||
// the oc10 signature check: $hash = \hash_pbkdf2("sha512", $url, $signingKey, 10000, 64, false);
|
||||
// - sets the length of the output string to 64
|
||||
// - sets raw output to false -> if raw_output is FALSE length corresponds to twice the byte-length of the derived key (as every byte of the key is returned as two hexits).
|
||||
@@ -197,8 +171,8 @@ func (m signedURLAuth) createSignature(url string, signingKey []byte) string {
|
||||
return hex.EncodeToString(hash)
|
||||
}
|
||||
|
||||
func (m signedURLAuth) getSigningKey(ctx context.Context, ocisID string) ([]byte, error) {
|
||||
res, err := m.store.Read(ctx, &storesvc.ReadRequest{
|
||||
func (m SignedURLAuthenticator) getSigningKey(ctx context.Context, ocisID string) ([]byte, error) {
|
||||
res, err := m.Store.Read(ctx, &storesvc.ReadRequest{
|
||||
Options: &storemsg.ReadOptions{
|
||||
Database: "proxy",
|
||||
Table: "signing-keys",
|
||||
@@ -206,8 +180,44 @@ func (m signedURLAuth) getSigningKey(ctx context.Context, ocisID string) ([]byte
|
||||
Key: ocisID,
|
||||
})
|
||||
if err != nil || len(res.Records) < 1 {
|
||||
return []byte{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Records[0].Value, nil
|
||||
}
|
||||
|
||||
// Authenticate implements the authenticator interface to authenticate requests via signed URL auth.
|
||||
func (m SignedURLAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
|
||||
if !m.shouldServe(r) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
user, _, err := m.UserProvider.GetUserByClaims(r.Context(), "username", r.URL.Query().Get(_paramOCCredential), true)
|
||||
if err != nil {
|
||||
m.Logger.Error().
|
||||
Err(err).
|
||||
Str("authenticator", "signed_url").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("Could not get user by claim")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ctx := revactx.ContextSetUser(r.Context(), user)
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
if err := m.validate(r); err != nil {
|
||||
m.Logger.Error().
|
||||
Err(err).
|
||||
Str("authenticator", "signed_url").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("Could not get user by claim")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
m.Logger.Debug().
|
||||
Str("authenticator", "signed_url").
|
||||
Str("path", r.URL.Path).
|
||||
Msg("successfully authenticated request")
|
||||
return r, true
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSignedURLAuth_shouldServe(t *testing.T) {
|
||||
pua := signedURLAuth{}
|
||||
pua := SignedURLAuthenticator{}
|
||||
tests := []struct {
|
||||
url string
|
||||
enabled bool
|
||||
@@ -20,7 +20,7 @@ func TestSignedURLAuth_shouldServe(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
pua.preSignedURLConfig.Enabled = tt.enabled
|
||||
pua.PreSignedURLConfig.Enabled = tt.enabled
|
||||
r := httptest.NewRequest("", tt.url, nil)
|
||||
result := pua.shouldServe(r)
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestSignedURLAuth_shouldServe(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignedURLAuth_allRequiredParametersPresent(t *testing.T) {
|
||||
pua := signedURLAuth{}
|
||||
pua := SignedURLAuthenticator{}
|
||||
baseURL := "https://example.com/example.jpg?"
|
||||
tests := []struct {
|
||||
params string
|
||||
@@ -54,7 +54,7 @@ func TestSignedURLAuth_allRequiredParametersPresent(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignedURLAuth_requestMethodMatches(t *testing.T) {
|
||||
pua := signedURLAuth{}
|
||||
pua := SignedURLAuthenticator{}
|
||||
tests := []struct {
|
||||
method string
|
||||
url string
|
||||
@@ -75,7 +75,7 @@ func TestSignedURLAuth_requestMethodMatches(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignedURLAuth_requestMethodIsAllowed(t *testing.T) {
|
||||
pua := signedURLAuth{}
|
||||
pua := SignedURLAuthenticator{}
|
||||
tests := []struct {
|
||||
method string
|
||||
allowed []string
|
||||
@@ -89,7 +89,7 @@ func TestSignedURLAuth_requestMethodIsAllowed(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
pua.preSignedURLConfig.AllowedHTTPMethods = tt.allowed
|
||||
pua.PreSignedURLConfig.AllowedHTTPMethods = tt.allowed
|
||||
ok, _ := pua.requestMethodIsAllowed(tt.method)
|
||||
|
||||
if ok != tt.expected {
|
||||
@@ -99,7 +99,7 @@ func TestSignedURLAuth_requestMethodIsAllowed(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignedURLAuth_urlIsExpired(t *testing.T) {
|
||||
pua := signedURLAuth{}
|
||||
pua := SignedURLAuthenticator{}
|
||||
nowFunc := func() time.Time {
|
||||
t, _ := time.Parse(time.RFC3339, "2020-02-02T12:30:00.000Z")
|
||||
return t
|
||||
@@ -126,7 +126,7 @@ func TestSignedURLAuth_urlIsExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSignedURLAuth_createSignature(t *testing.T) {
|
||||
pua := signedURLAuth{}
|
||||
pua := SignedURLAuthenticator{}
|
||||
expected := "27d2ebea381384af3179235114801dcd00f91e46f99fca72575301cf3948101d"
|
||||
s := pua.createSignature("something", []byte("somerandomkey"))
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
"config" : {
|
||||
"platform": {
|
||||
"php": "7.4"
|
||||
}
|
||||
},
|
||||
"allow-plugins": {
|
||||
"composer/package-versions-deprecated": true
|
||||
}
|
||||
},
|
||||
"require": {
|
||||
"behat/behat": "^3.9",
|
||||
|
||||
Reference in New Issue
Block a user