Parse role claims (#7713)

* extract and test role claim parsing

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add failing test

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* read segmented roles claim as array and string

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* reuse more code by extracting WalkSegments

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add TestSplitWithEscaping

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* docs and error for unhandled case

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add claims test

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add missing ReadStringClaim docs

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

---------

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
Jörn Friedrich Dreyer
2023-12-04 12:18:52 +01:00
committed by GitHub
parent 81ace6dd1d
commit 23e59b5ded
5 changed files with 409 additions and 32 deletions

View File

@@ -1,5 +1,10 @@
package oidc
import (
"fmt"
"strings"
)
const (
Iss = "iss"
Sub = "sub"
@@ -12,3 +17,60 @@ const (
OwncloudUUID = "ownclouduuid"
OcisRoutingPolicy = "ocis.routing.policy"
)
// SplitWithEscaping splits s into segments using separator which can be escaped using the escape string
// See https://codereview.stackexchange.com/a/280193
func SplitWithEscaping(s string, separator string, escapeString string) []string {
a := strings.Split(s, separator)
for i := len(a) - 2; i >= 0; i-- {
if strings.HasSuffix(a[i], escapeString) {
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
a = append(a[:i+1], a[i+2:]...)
}
}
return a
}
// WalkSegments uses the given array of segments to walk the claims and return whatever interface was found
func WalkSegments(segments []string, claims map[string]interface{}) (interface{}, error) {
i := 0
for ; i < len(segments)-1; i++ {
switch castedClaims := claims[segments[i]].(type) {
case map[string]interface{}:
claims = castedClaims
case map[interface{}]interface{}:
claims = make(map[string]interface{}, len(castedClaims))
for k, v := range castedClaims {
if s, ok := k.(string); ok {
claims[s] = v
} else {
return nil, fmt.Errorf("could not walk claims path, key '%v' is not a string", k)
}
}
default:
return nil, fmt.Errorf("unsupported type '%v'", castedClaims)
}
}
return claims[segments[i]], nil
}
// ReadStringClaim returns the string obtained by following the . seperated path in the claims
func ReadStringClaim(path string, claims map[string]interface{}) (string, error) {
// check the simple case first
value, _ := claims[path].(string)
if value != "" {
return value, nil
}
claim, err := WalkSegments(SplitWithEscaping(path, ".", "\\"), claims)
if err != nil {
return "", err
}
if value, _ = claim.(string); value != "" {
return value, nil
}
return value, fmt.Errorf("claim path '%s' not set or empty", path)
}

View File

@@ -0,0 +1,182 @@
package oidc_test
import (
"encoding/json"
"reflect"
"testing"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
)
type splitWithEscapingTest struct {
// Name of the subtest.
name string
// string to split
s string
// seperator to use
seperator string
// escape character to use for escaping
escape string
expectedParts []string
}
func (swet splitWithEscapingTest) run(t *testing.T) {
parts := oidc.SplitWithEscaping(swet.s, swet.seperator, swet.escape)
if len(swet.expectedParts) != len(parts) {
t.Errorf("mismatching length")
}
for i, v := range swet.expectedParts {
if parts[i] != v {
t.Errorf("expected part %d to be '%s', got '%s'", i, v, parts[i])
}
}
}
func TestSplitWithEscaping(t *testing.T) {
tests := []splitWithEscapingTest{
{
name: "plain claim name",
s: "roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"roles"},
},
{
name: "claim with .",
s: "my.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my", "roles"},
},
{
name: "claim with escaped .",
s: "my\\.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my.roles"},
},
{
name: "claim with escaped . left",
s: "my\\.other.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my.other", "roles"},
},
{
name: "claim with escaped . right",
s: "my.other\\.roles",
seperator: ".",
escape: "\\",
expectedParts: []string{"my", "other.roles"},
},
}
for _, test := range tests {
t.Run(test.name, test.run)
}
}
type walkSegmentsTest struct {
// Name of the subtest.
name string
// path segments to walk
segments []string
// seperator to use
claims map[string]interface{}
expected interface{}
wantErr bool
}
func (wst walkSegmentsTest) run(t *testing.T) {
v, err := oidc.WalkSegments(wst.segments, wst.claims)
if err != nil && !wst.wantErr {
t.Errorf("%v", err)
}
if err == nil && wst.wantErr {
t.Errorf("expected error")
}
if !reflect.DeepEqual(v, wst.expected) {
t.Errorf("expected %v got %v", wst.expected, v)
}
}
func TestWalkSegments(t *testing.T) {
byt := []byte(`{"first":{"second":{"third":["value1","value2"]},"foo":"bar"},"fizz":"buzz"}`)
var dat map[string]interface{}
if err := json.Unmarshal(byt, &dat); err != nil {
t.Errorf("%v", err)
}
tests := []walkSegmentsTest{
{
name: "one segment, single value",
segments: []string{"first"},
claims: map[string]interface{}{
"first": "value",
},
expected: "value",
wantErr: false,
},
{
name: "one segment, array value",
segments: []string{"first"},
claims: map[string]interface{}{
"first": []string{"value1", "value2"},
},
expected: []string{"value1", "value2"},
wantErr: false,
},
{
name: "two segments, single value",
segments: []string{"first", "second"},
claims: map[string]interface{}{
"first": map[string]interface{}{
"second": "value",
},
},
expected: "value",
wantErr: false,
},
{
name: "two segments, array value",
segments: []string{"first", "second"},
claims: map[string]interface{}{
"first": map[string]interface{}{
"second": []string{"value1", "value2"},
},
},
expected: []string{"value1", "value2"},
wantErr: false,
},
{
name: "three segments, array value from json",
segments: []string{"first", "second", "third"},
claims: dat,
expected: []interface{}{"value1", "value2"},
wantErr: false,
},
{
name: "three segments, array value with interface key",
segments: []string{"first", "second", "third"},
claims: map[string]interface{}{
"first": map[interface{}]interface{}{
"second": map[interface{}]interface{}{
"third": []string{"value1", "value2"},
},
},
},
expected: []string{"value1", "value2"},
wantErr: false,
},
}
for _, test := range tests {
t.Run(test.name, test.run)
}
}

View File

@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
"github.com/owncloud/ocis/v2/services/proxy/pkg/userroles"
@@ -43,19 +42,6 @@ type accountResolver struct {
userCS3Claim string
}
// from https://codereview.stackexchange.com/a/280193
func splitWithEscaping(s string, separator string, escapeString string) []string {
a := strings.Split(s, separator)
for i := len(a) - 2; i >= 0; i-- {
if strings.HasSuffix(a[i], escapeString) {
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
a = append(a[:i+1], a[i+2:]...)
}
}
return a
}
func readUserIDClaim(path string, claims map[string]interface{}) (string, error) {
// happy path
value, _ := claims[path].(string)
@@ -64,7 +50,7 @@ func readUserIDClaim(path string, claims map[string]interface{}) (string, error)
}
// try splitting path at .
segments := splitWithEscaping(path, ".", "\\")
segments := oidc.SplitWithEscaping(path, ".", "\\")
subclaims := claims
lastSegment := len(segments) - 1
for i := range segments {

View File

@@ -9,6 +9,7 @@ import (
cs3 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
"github.com/cs3org/reva/v2/pkg/utils"
"github.com/owncloud/ocis/v2/ocis-pkg/middleware"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
settingssvc "github.com/owncloud/ocis/v2/protogen/gen/ocis/services/settings/v0"
"go-micro.dev/v4/metadata"
)
@@ -29,6 +30,45 @@ func NewOIDCRoleAssigner(opts ...Option) UserRoleAssigner {
}
}
func extractRoles(rolesClaim string, claims map[string]interface{}) (map[string]struct{}, error) {
claimRoles := map[string]struct{}{}
// happy path
value, _ := claims[rolesClaim].(string)
if value != "" {
claimRoles[value] = struct{}{}
return claimRoles, nil
}
claim, err := oidc.WalkSegments(oidc.SplitWithEscaping(rolesClaim, ".", "\\"), claims)
if err != nil {
return nil, err
}
switch v := claim.(type) {
case []string:
for _, cr := range v {
claimRoles[cr] = struct{}{}
}
case []interface{}:
for _, cri := range v {
cr, ok := cri.(string)
if !ok {
err := errors.New("invalid role in claims")
return nil, err
}
claimRoles[cr] = struct{}{}
}
case string:
claimRoles[v] = struct{}{}
default:
return nil, errors.New("no roles in user claims")
}
return claimRoles, nil
}
// UpdateUserRoleAssignment assigns the role "User" to the supplied user. Unless the user
// already has a different role assigned.
func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *cs3.User, claims map[string]interface{}) (*cs3.User, error) {
@@ -39,23 +79,10 @@ func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *c
return nil, err
}
claimRolesRaw, ok := claims[ra.rolesClaim].([]interface{})
if !ok {
logger.Error().Str("rolesClaim", ra.rolesClaim).Msg("No roles in user claims")
return nil, errors.New("no roles in user claims")
}
logger.Debug().Str("rolesClaim", ra.rolesClaim).Interface("rolesInClaim", claims[ra.rolesClaim]).Msg("got roles in claim")
claimRoles := map[string]struct{}{}
for _, cri := range claimRolesRaw {
cr, ok := cri.(string)
if !ok {
err := errors.New("invalid role in claims")
logger.Error().Err(err).Interface("claimValue", cri).Msg("Is not a valid string.")
return nil, err
}
claimRoles[cr] = struct{}{}
claimRoles, err := extractRoles(ra.rolesClaim, claims)
if err != nil {
logger.Error().Err(err).Msg("Error mapping role names to role ids")
return nil, err
}
if len(claimRoles) == 0 {

View File

@@ -0,0 +1,120 @@
package userroles
import (
"encoding/json"
"testing"
)
func TestExtractRolesArray(t *testing.T) {
byt := []byte(`{"roles":["a","b"]}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
if _, ok := roles["b"]; !ok {
t.Fatal("must contain 'b'")
}
}
func TestExtractRolesString(t *testing.T) {
byt := []byte(`{"roles":"a"}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
}
func TestExtractRolesPathArray(t *testing.T) {
byt := []byte(`{"sub":{"roles":["a","b"]}}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub.roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
if _, ok := roles["b"]; !ok {
t.Fatal("must contain 'b'")
}
}
func TestExtractRolesPathString(t *testing.T) {
byt := []byte(`{"sub":{"roles":"a"}}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub.roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
}
func TestExtractEscapedRolesPathString(t *testing.T) {
byt := []byte(`{"sub.roles":"a"}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub\\.roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
}
func TestNoRoles(t *testing.T) {
byt := []byte(`{"sub":{"foo":"a"}}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub.roles", claims)
if err == nil {
t.Fatal("must not find a role")
}
if len(roles) != 0 {
t.Fatal("length of roles mut be 0")
}
}