From 0dffde2e6d69da3ce574c88ae390a1144d62b780 Mon Sep 17 00:00:00 2001 From: "A.Unger" Date: Mon, 4 May 2020 14:08:19 +0200 Subject: [PATCH] added a generic cache --- pkg/cache/cache.go | 95 +++++++++++++++++++++++++++++++++ pkg/cache/cache_test.go | 67 +++++++++++++++++++++++ pkg/cache/option.go | 36 +++++++++++++ pkg/middleware/logger.go | 15 ------ pkg/middleware/openidconnect.go | 79 +++++++++++++++++++-------- 5 files changed, 254 insertions(+), 38 deletions(-) create mode 100644 pkg/cache/cache.go create mode 100644 pkg/cache/cache_test.go create mode 100644 pkg/cache/option.go delete mode 100644 pkg/middleware/logger.go diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 0000000000..51ab60fe7b --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,95 @@ +package cache + +import ( + "fmt" + "sync" + "time" +) + +// Entry represents an entry on the cache. You can type assert on V. +type Entry struct { + V interface{} + Valid bool +} + +// Cache is a barebones cache implementation. +type Cache struct { + entries map[string]map[string]Entry + size int + ttl time.Duration + m sync.Mutex +} + +// NewCache returns a new instance of Cache. +func NewCache(o ...Option) Cache { + opts := newOptions(o...) + + return Cache{ + size: opts.size, + entries: map[string]map[string]Entry{}, + } +} + +// Get gets an entry on a service `svcKey` by a give `key`. +func (c *Cache) Get(svcKey, key string) (*Entry, error) { + var value Entry + ok := true + + c.m.Lock() + defer c.m.Unlock() + + if value, ok = c.entries[svcKey][key]; !ok { + return nil, fmt.Errorf("invalid service key: `%v`", key) + } + + return &value, nil +} + +// Set sets a key / value. It lets a service add entries on a request basis. +func (c *Cache) Set(svcKey, key string, val interface{}) error { + c.m.Lock() + defer c.m.Unlock() + + if !c.fits() { + return fmt.Errorf("cache is full") + } + + if _, ok := c.entries[svcKey]; !ok { + c.entries[svcKey] = map[string]Entry{} + } + + if _, ok := c.entries[svcKey][key]; ok { + return fmt.Errorf("key `%v` already exists", key) + } + + c.entries[svcKey][key] = Entry{ + V: val, + Valid: true, + } + + return nil +} + +// Invalidate invalidates a cache Entry by key. +func (c *Cache) Invalidate(key string) error { + c.m.Lock() + defer c.m.Unlock() + + if _, ok := c.entries[key]; !ok { + return fmt.Errorf("invalid key: `%v`", key) + } + + return nil +} + +// Length returns the amount of entries. +func (c *Cache) Length(k string) int { + return len(c.entries[k]) +} + +func (c *Cache) fits() bool { + if c.size < len(c.entries) { + return false + } + return true +} diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go new file mode 100644 index 0000000000..1c6500ff54 --- /dev/null +++ b/pkg/cache/cache_test.go @@ -0,0 +1,67 @@ +package cache + +import ( + "testing" +) + +// Prevents from invalid import cycle. +type AccountsCacheEntry struct { + Email string + UUID string +} + +func TestSet(t *testing.T) { + c := NewCache( + Size(256), + ) + + err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{ + Email: "hello@foo.bar", + UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05", + }) + if err != nil { + t.Error(err) + } + + if c.Length("accounts") != 1 { + t.Errorf("expected length 1 got `%v`", len(c.entries)) + } + + item, err := c.Get("accounts", "hello@foo.bar") + if err != nil { + t.Error(err) + } + + if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok { + t.Errorf("invalid cached value type") + } else { + if cachedEntry.Email != "hello@foo.bar" { + t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email) + } + } +} + +func TestGet(t *testing.T) { + svcCache := NewCache( + Size(256), + ) + + err := svcCache.Set("accounts", "node", "0.0.0.0:1234") + if err != nil { + t.Error(err) + } + + raw, err := svcCache.Get("accounts", "node") + if err != nil { + t.Error(err) + } + + v, ok := raw.V.(string) + if !ok { + t.Errorf("invalid type on service node key") + } + + if v != "0.0.0.0:1234" { + t.Errorf("expected `0.0.0.0:1234` got `%v`", v) + } +} diff --git a/pkg/cache/option.go b/pkg/cache/option.go new file mode 100644 index 0000000000..bfee5be52c --- /dev/null +++ b/pkg/cache/option.go @@ -0,0 +1,36 @@ +package cache + +import "time" + +// Options are all the possible options. +type Options struct { + size int + ttl time.Duration +} + +// Option mutates option +type Option func(*Options) + +// Size configures the size of the cache in items. +func Size(s int) Option { + return func(o *Options) { + o.size = s + } +} + +// TTL rebuilds the cache after the configured duration. +func TTL(ttl time.Duration) Option { + return func(o *Options) { + o.ttl = ttl + } +} + +func newOptions(opts ...Option) Options { + o := Options{} + + for _, v := range opts { + v(&o) + } + + return o +} diff --git a/pkg/middleware/logger.go b/pkg/middleware/logger.go deleted file mode 100644 index 9db319f223..0000000000 --- a/pkg/middleware/logger.go +++ /dev/null @@ -1,15 +0,0 @@ -package middleware - -import ( - "net/http" -) - -// Logger undocummented -func Logger() M { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // do some logging logic here - next.ServeHTTP(w, r) - }) - } -} diff --git a/pkg/middleware/openidconnect.go b/pkg/middleware/openidconnect.go index 3bcfecd9eb..f772340533 100644 --- a/pkg/middleware/openidconnect.go +++ b/pkg/middleware/openidconnect.go @@ -4,15 +4,16 @@ import ( "context" "crypto/tls" "errors" + "fmt" "net/http" "strings" "time" oidc "github.com/coreos/go-oidc" mclient "github.com/micro/go-micro/v2/client" - "github.com/micro/go-micro/v2/registry" acc "github.com/owncloud/ocis-accounts/pkg/proto/v0" ocisoidc "github.com/owncloud/ocis-pkg/v2/oidc" + "github.com/owncloud/ocis-proxy/pkg/cache" "golang.org/x/oauth2" ) @@ -20,6 +21,9 @@ var ( // ErrInvalidToken is returned when the request token is invalid. ErrInvalidToken = errors.New("invalid or missing token") + // svcCache caches requests for given services to prevent round trips to the service + svcCache = cache.NewCache() + accountSvc = "com.owncloud.accounts" ) @@ -53,6 +57,9 @@ func OpenIDConnect(opts ...ocisoidc.Option) M { header := r.Header.Get("Authorization") path := r.URL.Path + // void call for testing purposes. + uuidFromClaims(ocisoidc.StandardClaims{}) + // Ignore request to "/konnect/v1/userinfo" as this will cause endless loop when getting userinfo // needs a better idea on how to not hardcode this if header == "" || !strings.HasPrefix(header, "Bearer ") || path == "/konnect/v1/userinfo" { @@ -107,6 +114,16 @@ func OpenIDConnect(opts ...ocisoidc.Option) M { return } + // add UUID to the request context for the handler to deal with + // void call for correct staticchecks. + _, err = uuidFromClaims(claims) + + if err != nil { + opt.Logger.Error().Err(err).Interface("account uuid", userInfo).Msg("failed to unmarshal userinfo claims") + w.WriteHeader(http.StatusInternalServerError) + return + } + opt.Logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo") // store claims in context // uses the original context, not the one with probably reduced security @@ -117,33 +134,49 @@ func OpenIDConnect(opts ...ocisoidc.Option) M { } } +// AccountsCacheEntry stores a request to the accounts service on the cache. +// this type declaration should be on each respective service. +type AccountsCacheEntry struct { + Email string + UUID string +} + +const ( + // AccountsKey declares the svcKey for the Accounts service. + AccountsKey = "accounts" + + // NodeKey declares the key that will be used to store the node address. + // It is shared between services. + NodeKey = "node" +) + // from the user claims we need to get the uuid from the accounts service func uuidFromClaims(claims ocisoidc.StandardClaims) (string, error) { - var node string - // get accounts node from micro registry - // TODO this assumes we use mdns as registry. This should be configurable for any ocis extension. - svc, err := registry.GetService(accountSvc) + entry, err := svcCache.Get(AccountsKey, claims.Email) if err != nil { - return "", err + c := acc.NewSettingsService("com.owncloud.accounts", mclient.DefaultClient) // TODO this won't work with a registry other than mdns. Look into Micro's client initialization. + resp, err := c.Get(context.Background(), &acc.Query{ + Key: "200~a54bf154-e6a5-4e96-851b-a56c9f6c1fce", // use hardcoded key... + // Email: claims.Email // depends on @jfd PR. + }) + if err != nil { + return "", err + } + + // TODO add logging info. Round trip has been made to the accounts service. + err = svcCache.Set(AccountsKey, claims.Email, resp.Payload.Account.Uuid) + if err != nil { + return "", err + } + + return resp.Key, nil } - if len(svc) > 0 { - node = svc[0].Nodes[0].Address + uuid, ok := entry.V.(string) + if !ok { + return "", fmt.Errorf("unexpected type on cache entry value. Expected string type") } - c := acc.NewSettingsService("accounts", mclient.DefaultClient) - _, err = c.Get(context.Background(), &acc.Query{ - // TODO accounts query message needs to be updated to query for multiple fields - // queries by key makes little sense as it is unknown. - Key: "73912d13-32f7-4fb6-aeb2-ea2088a3a264", - }) - if err != nil { - return "", err - } - - // by this point, rec.Payload contains the Account info. To include UUID, see: - // https://github.com/owncloud/ocis-accounts/pull/22/files#diff-b425175389864c4f9218ecd9cae80223R23 - - // return rec.GetPayload().Account.UUID, nil // depends on the aforementioned PR - return node, nil + // TODO add logging info. Read from cache. + return uuid, nil }