refactor proxy code

I refactored the proxy so that we execute the routing before the
authentication middleware. This is necessary so that we can determine
which routes are considered unprotected i.e. which routes don't need
authentication.
This commit is contained in:
David Christofas
2022-08-25 16:56:16 +02:00
committed by Ralf Haferkamp
parent 48b0425fed
commit 4d4f3a16e1
4 changed files with 307 additions and 262 deletions

View File

@@ -26,6 +26,7 @@ import (
"github.com/owncloud/ocis/v2/services/proxy/pkg/metrics"
"github.com/owncloud/ocis/v2/services/proxy/pkg/middleware"
"github.com/owncloud/ocis/v2/services/proxy/pkg/proxy"
"github.com/owncloud/ocis/v2/services/proxy/pkg/router"
"github.com/owncloud/ocis/v2/services/proxy/pkg/server/debug"
proxyHTTP "github.com/owncloud/ocis/v2/services/proxy/pkg/server/http"
"github.com/owncloud/ocis/v2/services/proxy/pkg/tracing"
@@ -211,6 +212,8 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config)
oidcHTTPClient,
),
router.Middleware(cfg.PolicySelector, cfg.Policies, logger),
middleware.Authentication(
authenticators,
middleware.CredentialsByUserAgent(cfg.AuthMiddleware.CredentialsByUserAgent),

View File

@@ -6,21 +6,17 @@ import (
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
"time"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"go-micro.dev/v4/selector"
"go.opentelemetry.io/otel/attribute"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/registry"
pkgtrace "github.com/owncloud/ocis/v2/ocis-pkg/tracing"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
"github.com/owncloud/ocis/v2/services/proxy/pkg/proxy/policy"
"github.com/owncloud/ocis/v2/services/proxy/pkg/router"
proxytracing "github.com/owncloud/ocis/v2/services/proxy/pkg/tracing"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
@@ -45,7 +41,11 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
logger: options.Logger,
config: options.Config,
}
rp.Director = rp.directorSelectionDirector
rp.Director = func(r *http.Request) {
fn := router.DirectorFunc(r.Context())
fn(r)
}
// equals http.DefaultTransport except TLSClientConfig
rp.Transport = &http.Transport{
@@ -64,193 +64,9 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
InsecureSkipVerify: options.Config.InsecureBackends, //nolint:gosec
},
}
if options.Config.PolicySelector == nil {
firstPolicy := options.Config.Policies[0].Name
rp.logger.Warn().Str("policy", firstPolicy).Msg("policy-selector not configured. Will always use first policy")
options.Config.PolicySelector = &config.PolicySelector{
Static: &config.StaticSelectorConf{
Policy: firstPolicy,
},
}
}
rp.logger.Debug().
Interface("selector_config", options.Config.PolicySelector).
Msg("loading policy-selector")
policySelector, err := policy.LoadSelector(options.Config.PolicySelector)
if err != nil {
rp.logger.Fatal().Err(err).Msg("Could not load policy-selector")
}
rp.PolicySelector = policySelector
for _, pol := range options.Config.Policies {
for _, route := range pol.Routes {
rp.logger.Debug().Str("fwd: ", route.Endpoint)
if route.Backend == "" && route.Service == "" {
rp.logger.Fatal().Interface("route", route).Msg("neither Backend nor Service is set")
}
uri, err2 := url.Parse(route.Backend)
if err2 != nil {
rp.logger.
Fatal(). // fail early on misconfiguration
Err(err2).
Str("backend", route.Backend).
Msg("malformed url")
}
// here the backend is used as a uri
rp.AddHost(pol.Name, uri, route)
}
}
return rp
}
func (p *MultiHostReverseProxy) directorSelectionDirector(r *http.Request) {
pol, err := p.PolicySelector(r)
if err != nil {
p.logger.Error().Err(err).Msg("Error while selecting pol")
return
}
if _, ok := p.Directors[pol]; !ok {
p.logger.
Error().
Str("policy", pol).
Msg("policy is not configured")
return
}
method := ""
// find matching director
for _, rt := range config.RouteTypes {
var handler func(string, url.URL) bool
switch rt {
case config.QueryRoute:
handler = p.queryRouteMatcher
case config.RegexRoute:
handler = p.regexRouteMatcher
case config.PrefixRoute:
fallthrough
default:
handler = p.prefixRouteMatcher
}
if p.Directors[pol][rt][r.Method] != nil {
// use specific method
method = r.Method
}
for endpoint := range p.Directors[pol][rt][method] {
if handler(endpoint, *r.URL) {
p.logger.Debug().
Str("policy", pol).
Str("method", r.Method).
Str("prefix", endpoint).
Str("path", r.URL.Path).
Str("routeType", string(rt)).
Msg("director found")
p.Directors[pol][rt][method][endpoint](r)
return
}
}
}
// override default director with root. If any
switch {
case p.Directors[pol][config.PrefixRoute][method]["/"] != nil:
// try specific method
p.Directors[pol][config.PrefixRoute][method]["/"](r)
return
case p.Directors[pol][config.PrefixRoute][""]["/"] != nil:
// fallback to unspecific method
p.Directors[pol][config.PrefixRoute][""]["/"](r)
return
}
p.logger.
Warn().
Str("policy", pol).
Str("path", r.URL.Path).
Msg("no director found")
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// AddHost undocumented
func (p *MultiHostReverseProxy) AddHost(policy string, target *url.URL, rt config.Route) {
targetQuery := target.RawQuery
if p.Directors[policy] == nil {
p.Directors[policy] = make(map[config.RouteType]map[string]map[string]func(req *http.Request))
}
routeType := config.DefaultRouteType
if rt.Type != "" {
routeType = rt.Type
}
if p.Directors[policy][routeType] == nil {
p.Directors[policy][routeType] = make(map[string]map[string]func(req *http.Request))
}
if p.Directors[policy][routeType][rt.Method] == nil {
p.Directors[policy][routeType][rt.Method] = make(map[string]func(req *http.Request))
}
reg := registry.GetRegistry()
sel := selector.NewSelector(selector.Registry(reg))
p.Directors[policy][routeType][rt.Method][rt.Endpoint] = func(req *http.Request) {
if rt.Service != "" {
// select next node
next, err := sel.Select(rt.Service)
if err != nil {
fmt.Println(fmt.Errorf("could not select %s service from the registry: %v", rt.Service, err))
return // TODO error? fallback to target.Host & Scheme?
}
node, err := next()
if err != nil {
fmt.Println(fmt.Errorf("could not select next node for service %s: %v", rt.Service, err))
return // TODO error? fallback to target.Host & Scheme?
}
req.URL.Host = node.Address
req.URL.Scheme = node.Metadata["protocol"] // TODO check property exists?
} else {
req.URL.Host = target.Host
req.URL.Scheme = target.Scheme
}
// Apache deployments host addresses need to match on req.Host and req.URL.Host
// see https://stackoverflow.com/questions/34745654/golang-reverseproxy-with-apache2-sni-hostname-error
if rt.ApacheVHost {
req.Host = target.Host
}
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
}
}
func (p *MultiHostReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
@@ -271,33 +87,3 @@ func (p *MultiHostReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request
p.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
}
func (p MultiHostReverseProxy) queryRouteMatcher(endpoint string, target url.URL) bool {
u, _ := url.Parse(endpoint)
if !strings.HasPrefix(target.Path, u.Path) || endpoint == "/" {
return false
}
q := u.Query()
if len(q) == 0 {
return false
}
tq := target.Query()
for k := range q {
if q.Get(k) != tq.Get(k) {
return false
}
}
return true
}
func (p *MultiHostReverseProxy) regexRouteMatcher(pattern string, target url.URL) bool {
matched, err := regexp.MatchString(pattern, target.String())
if err != nil {
p.logger.Warn().Err(err).Str("pattern", pattern).Msg("regex with pattern failed")
}
return matched
}
func (p *MultiHostReverseProxy) prefixRouteMatcher(prefix string, target url.URL) bool {
return strings.HasPrefix(target.Path, prefix) && prefix != "/"
}

View File

@@ -0,0 +1,261 @@
package router
import (
"context"
"net/http"
"net/url"
"regexp"
"strings"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/registry"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
"github.com/owncloud/ocis/v2/services/proxy/pkg/proxy/policy"
"go-micro.dev/v4/selector"
)
const directorCtxKey string = "director"
func Middleware(policySelector *config.PolicySelector, policies []config.Policy, logger log.Logger) func(http.Handler) http.Handler {
router := New(policySelector, policies, logger)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fn := router.Route(r)
next.ServeHTTP(w, r.WithContext(SetDirectorFunc(r.Context(), fn)))
})
}
}
func (rt Router) Route(r *http.Request) func(*http.Request) {
pol, err := rt.policySelector(r)
if err != nil {
rt.logger.Error().Err(err).Msg("Error while selecting pol")
return nil
}
if _, ok := rt.directors[pol]; !ok {
rt.logger.
Error().
Str("policy", pol).
Msg("policy is not configured")
return nil
}
method := ""
// find matching director
for _, rtype := range config.RouteTypes {
var handler func(string, url.URL) bool
switch rtype {
case config.QueryRoute:
handler = queryRouteMatcher
case config.RegexRoute:
handler = rt.regexRouteMatcher
case config.PrefixRoute:
fallthrough
default:
handler = prefixRouteMatcher
}
if rt.directors[pol][rtype][r.Method] != nil {
// use specific method
method = r.Method
}
for endpoint := range rt.directors[pol][rtype][method] {
if handler(endpoint, *r.URL) {
rt.logger.Debug().
Str("policy", pol).
Str("method", r.Method).
Str("prefix", endpoint).
Str("path", r.URL.Path).
Str("routeType", string(rtype)).
Msg("director found")
return rt.directors[pol][rtype][method][endpoint]
}
}
}
// override default director with root. If any
switch {
case rt.directors[pol][config.PrefixRoute][method]["/"] != nil:
// try specific method
return rt.directors[pol][config.PrefixRoute][method]["/"]
case rt.directors[pol][config.PrefixRoute][""]["/"] != nil:
// fallback to unspecific method
return rt.directors[pol][config.PrefixRoute][""]["/"]
}
rt.logger.
Warn().
Str("policy", pol).
Str("path", r.URL.Path).
Msg("no director found")
return nil
}
func New(policySelector *config.PolicySelector, policies []config.Policy, logger log.Logger) Router {
if policySelector == nil {
firstPolicy := policies[0].Name
logger.Warn().Str("policy", firstPolicy).Msg("policy-selector not configured. Will always use first policy")
policySelector = &config.PolicySelector{
Static: &config.StaticSelectorConf{
Policy: firstPolicy,
},
}
}
logger.Debug().
Interface("selector_config", policySelector).
Msg("loading policy-selector")
selector, err := policy.LoadSelector(policySelector)
if err != nil {
logger.Fatal().Err(err).Msg("Could not load policy-selector")
}
r := Router{
directors: make(map[string]map[config.RouteType]map[string]map[string]func(req *http.Request)),
policySelector: selector,
}
for _, pol := range policies {
for _, route := range pol.Routes {
logger.Debug().Str("fwd: ", route.Endpoint)
if route.Backend == "" && route.Service == "" {
logger.Fatal().Interface("route", route).Msg("neither Backend nor Service is set")
}
uri, err2 := url.Parse(route.Backend)
if err2 != nil {
logger.
Fatal(). // fail early on misconfiguration
Err(err2).
Str("backend", route.Backend).
Msg("malformed url")
}
// here the backend is used as a uri
r.addHost(pol.Name, uri, route)
}
}
return r
}
type Router struct {
logger log.Logger
directors map[string]map[config.RouteType]map[string]map[string]func(req *http.Request)
policySelector policy.Selector
}
func (rt Router) addHost(policy string, target *url.URL, route config.Route) {
targetQuery := target.RawQuery
if rt.directors[policy] == nil {
rt.directors[policy] = make(map[config.RouteType]map[string]map[string]func(req *http.Request))
}
routeType := config.DefaultRouteType
if route.Type != "" {
routeType = route.Type
}
if rt.directors[policy][routeType] == nil {
rt.directors[policy][routeType] = make(map[string]map[string]func(req *http.Request))
}
if rt.directors[policy][routeType][route.Method] == nil {
rt.directors[policy][routeType][route.Method] = make(map[string]func(req *http.Request))
}
reg := registry.GetRegistry()
sel := selector.NewSelector(selector.Registry(reg))
rt.directors[policy][routeType][route.Method][route.Endpoint] = func(req *http.Request) {
if route.Service != "" {
// select next node
next, err := sel.Select(route.Service)
if err != nil {
rt.logger.Error().Err(err).
Str("service", route.Service).
Msg("could not select service from the registry")
return // TODO error? fallback to target.Host & Scheme?
}
node, err := next()
if err != nil {
rt.logger.Error().Err(err).
Str("service", route.Service).
Msg("could not select next node")
return // TODO error? fallback to target.Host & Scheme?
}
req.URL.Host = node.Address
req.URL.Scheme = node.Metadata["protocol"] // TODO check property exists?
} else {
req.URL.Host = target.Host
req.URL.Scheme = target.Scheme
}
// Apache deployments host addresses need to match on req.Host and req.URL.Host
// see https://stackoverflow.com/questions/34745654/golang-reverseproxy-with-apache2-sni-hostname-error
if route.ApacheVHost {
req.Host = target.Host
}
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
}
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
func queryRouteMatcher(endpoint string, target url.URL) bool {
u, _ := url.Parse(endpoint)
if !strings.HasPrefix(target.Path, u.Path) || endpoint == "/" {
return false
}
q := u.Query()
if len(q) == 0 {
return false
}
tq := target.Query()
for k := range q {
if q.Get(k) != tq.Get(k) {
return false
}
}
return true
}
func (rt Router) regexRouteMatcher(pattern string, target url.URL) bool {
matched, err := regexp.MatchString(pattern, target.String())
if err != nil {
rt.logger.Warn().Err(err).Str("pattern", pattern).Msg("regex with pattern failed")
}
return matched
}
func prefixRouteMatcher(prefix string, target url.URL) bool {
return strings.HasPrefix(target.Path, prefix) && prefix != "/"
}
func SetDirectorFunc(parent context.Context, fn func(*http.Request)) context.Context {
return context.WithValue(parent, directorCtxKey, fn)
}
// DirectorFunc gets the director function from the context.
func DirectorFunc(ctx context.Context) func(*http.Request) {
val := ctx.Value(directorCtxKey)
return val.(func(*http.Request))
}

View File

@@ -1,13 +1,10 @@
package proxy
package router
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/services/proxy/pkg/config/defaults"
)
@@ -19,7 +16,6 @@ type matchertest struct {
func TestPrefixRouteMatcher(t *testing.T) {
cfg := defaults.DefaultConfig()
cfg.Policies = defaults.DefaultPolicies()
p := NewMultiHostReverseProxy(Config(cfg))
table := []matchertest{
{endpoint: "/foobar", target: "/foobar/baz/some/url", matches: true},
@@ -28,7 +24,7 @@ func TestPrefixRouteMatcher(t *testing.T) {
for _, test := range table {
u, _ := url.Parse(test.target)
matched := p.prefixRouteMatcher(test.endpoint, *u)
matched := prefixRouteMatcher(test.endpoint, *u)
if matched != test.matches {
t.Errorf("PrefixRouteMatcher returned %t expected %t for endpoint: %s and target %s",
matched, test.matches, test.endpoint, u.String())
@@ -39,7 +35,6 @@ func TestPrefixRouteMatcher(t *testing.T) {
func TestQueryRouteMatcher(t *testing.T) {
cfg := defaults.DefaultConfig()
cfg.Policies = defaults.DefaultPolicies()
p := NewMultiHostReverseProxy(Config(cfg))
table := []matchertest{
{endpoint: "/foobar?parameter=true", target: "/foobar/baz/some/url?parameter=true", matches: true},
@@ -56,7 +51,7 @@ func TestQueryRouteMatcher(t *testing.T) {
for _, test := range table {
u, _ := url.Parse(test.target)
matched := p.queryRouteMatcher(test.endpoint, *u)
matched := queryRouteMatcher(test.endpoint, *u)
if matched != test.matches {
t.Errorf("QueryRouteMatcher returned %t expected %t for endpoint: %s and target %s",
matched, test.matches, test.endpoint, u.String())
@@ -67,7 +62,7 @@ func TestQueryRouteMatcher(t *testing.T) {
func TestRegexRouteMatcher(t *testing.T) {
cfg := defaults.DefaultConfig()
cfg.Policies = defaults.DefaultPolicies()
p := NewMultiHostReverseProxy(Config(cfg))
rt := New(cfg.PolicySelector, cfg.Policies, log.NewLogger())
table := []matchertest{
{endpoint: ".*some\\/url.*parameter=true", target: "/foobar/baz/some/url?parameter=true", matches: true},
@@ -76,7 +71,7 @@ func TestRegexRouteMatcher(t *testing.T) {
for _, test := range table {
u, _ := url.Parse(test.target)
matched := p.regexRouteMatcher(test.endpoint, *u)
matched := rt.regexRouteMatcher(test.endpoint, *u)
if matched != test.matches {
t.Errorf("RegexRouteMatcher returned %t expected %t for endpoint: %s and target %s",
matched, test.matches, test.endpoint, u.String())
@@ -104,34 +99,34 @@ func TestSingleJoiningSlash(t *testing.T) {
}
}
func TestDirectorSelectionDirector(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok")
}))
defer svr.Close()
p := NewMultiHostReverseProxy(Config(&config.Config{
PolicySelector: &config.PolicySelector{
Static: &config.StaticSelectorConf{
Policy: "default",
},
},
}))
p.AddHost("default", &url.URL{Host: "ocdav"}, config.Route{Type: config.PrefixRoute, Method: "", Endpoint: "/dav", Backend: "ocdav"})
p.AddHost("default", &url.URL{Host: "ocis-webdav"}, config.Route{Type: config.PrefixRoute, Method: "REPORT", Endpoint: "/dav", Backend: "ocis-webdav"})
table := []matchertest{
{method: "PROPFIND", endpoint: "/dav/files/demo/", target: "ocdav"},
{method: "REPORT", endpoint: "/dav/files/demo/", target: "ocis-webdav"},
}
for _, test := range table {
r := httptest.NewRequest(test.method, "/dav/files/demo/", nil)
p.directorSelectionDirector(r)
if r.URL.Host != test.target {
t.Errorf("TestDirectorSelectionDirector got host %s expected %s", r.Host, test.target)
}
}
}
// func TestDirectorSelectionDirector(t *testing.T) {
//
// svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// fmt.Fprintf(w, "ok")
// }))
// defer svr.Close()
//
// p := NewMultiHostReverseProxy(Config(&config.Config{
// PolicySelector: &config.PolicySelector{
// Static: &config.StaticSelectorConf{
// Policy: "default",
// },
// },
// }))
// p.AddHost("default", &url.URL{Host: "ocdav"}, config.Route{Type: config.PrefixRoute, Method: "", Endpoint: "/dav", Backend: "ocdav"})
// p.AddHost("default", &url.URL{Host: "ocis-webdav"}, config.Route{Type: config.PrefixRoute, Method: "REPORT", Endpoint: "/dav", Backend: "ocis-webdav"})
//
// table := []matchertest{
// {method: "PROPFIND", endpoint: "/dav/files/demo/", target: "ocdav"},
// {method: "REPORT", endpoint: "/dav/files/demo/", target: "ocis-webdav"},
// }
//
// for _, test := range table {
// r := httptest.NewRequest(test.method, "/dav/files/demo/", nil)
// p.directorSelectionDirector(r)
// if r.URL.Host != test.target {
// t.Errorf("TestDirectorSelectionDirector got host %s expected %s", r.Host, test.target)
//
// }
// }
// }