diff --git a/changelog/unreleased/fix-graph-education-createschool.md b/changelog/unreleased/fix-graph-education-createschool.md new file mode 100644 index 0000000000..4d867a4556 --- /dev/null +++ b/changelog/unreleased/fix-graph-education-createschool.md @@ -0,0 +1,6 @@ +Bugfix: Check school number for duplicates before adding a school + +We fixed an issue that allowed to create two schools with the same school number + +https://github.com/owncloud/ocis/pull/7351 +https://github.com/owncloud/enterprise/issues/6051 diff --git a/services/graph/pkg/identity/ldap_education_class_test.go b/services/graph/pkg/identity/ldap_education_class_test.go index b8f481b142..b6f34cc8a3 100644 --- a/services/graph/pkg/identity/ldap_education_class_test.go +++ b/services/graph/pkg/identity/ldap_education_class_test.go @@ -72,9 +72,7 @@ func TestGetEducationClasses(t *testing.T) { lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetEducationClasses(context.Background()) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") lm = &mocks.Client{} lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) @@ -156,7 +154,7 @@ func TestGetEducationClass(t *testing.T) { if tt.expectedItemNotFound { assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } else { assert.Nil(t, err) assert.Equal(t, "Math", class.GetDisplayName()) @@ -228,7 +226,7 @@ func TestDeleteEducationClass(t *testing.T) { if tt.expectedItemNotFound { lm.AssertNumberOfCalls(t, "Del", 0) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } else { assert.Nil(t, err) } @@ -301,7 +299,7 @@ func TestGetEducationClassMembers(t *testing.T) { if tt.expectedItemNotFound { lm.AssertNumberOfCalls(t, "Search", 1) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } else { lm.AssertNumberOfCalls(t, "Search", 2) assert.Nil(t, err) diff --git a/services/graph/pkg/identity/ldap_education_school.go b/services/graph/pkg/identity/ldap_education_school.go index 6274c2e8f0..fe02f4b262 100644 --- a/services/graph/pkg/identity/ldap_education_school.go +++ b/services/graph/pkg/identity/ldap_education_school.go @@ -49,7 +49,11 @@ const ( const ldapDateFormat = "20060102150405Z0700" -var errNotSet = errors.New("Attribute not set") +var ( + errNotSet = errors.New("Attribute not set") + errSchoolNameExists = errorcode.New(errorcode.NameAlreadyExists, "A school with that name is already present") + errSchoolNumberExists = errorcode.New(errorcode.NameAlreadyExists, "A school with that number is already present") +) func defaultEducationConfig() educationConfig { return educationConfig{ @@ -116,7 +120,18 @@ func (i *LDAP) CreateEducationSchool(ctx context.Context, school libregraph.Educ return nil, ErrReadOnly } - // Here we should verify that the school number is not already used + // Check that the school number is not already used + _, err := i.getSchoolByNumber(school.GetSchoolNumber()) + switch err { + case nil: + logger.Debug().Err(errSchoolNumberExists).Str("schoolNumber", school.GetSchoolNumber()).Msg("duplicate school number") + return nil, errSchoolNumberExists + case ErrNotFound: + break + default: + logger.Error().Err(err).Str("schoolNumber", school.GetSchoolNumber()).Msg("error looking up school by number") + return nil, errorcode.New(errorcode.GeneralException, "error looking up school by number") + } attributeTypeAndValue := ldap.AttributeTypeAndValue{ Type: i.educationConfig.schoolAttributeMap.displayName, @@ -141,7 +156,7 @@ func (i *LDAP) CreateEducationSchool(ctx context.Context, school libregraph.Educ logger.Debug().Err(err).Msg("error adding school") if errors.As(err, &lerr) { if lerr.ResultCode == ldap.LDAPResultEntryAlreadyExists { - err = errorcode.New(errorcode.NameAlreadyExists, lerr.Error()) + err = errSchoolNameExists } } return nil, err @@ -156,7 +171,7 @@ func (i *LDAP) CreateEducationSchool(ctx context.Context, school libregraph.Educ } // UpdateEducationSchoolOperation contains the logic for which update operation to apply to a school -func (i *LDAP) UpdateEducationSchoolOperation( +func (i *LDAP) updateEducationSchoolOperation( schoolUpdate libregraph.EducationSchool, currentSchool libregraph.EducationSchool, ) schoolUpdateOperation { @@ -216,7 +231,7 @@ func (i *LDAP) updateDisplayName(ctx context.Context, dn string, providedDisplay logger.Debug().Err(err).Msg("error updating school name") if errors.As(err, &lerr) { if lerr.ResultCode == ldap.LDAPResultEntryAlreadyExists { - err = errorcode.New(errorcode.NameAlreadyExists, lerr.Error()) + err = errSchoolNameExists } } return err @@ -233,11 +248,9 @@ func (i *LDAP) updateSchoolProperties(ctx context.Context, dn string, currentSch mr := ldap.NewModifyRequest(dn, nil) if updatedSchoolNumber, ok := updatedSchool.GetSchoolNumberOk(); ok { if *updatedSchoolNumber != "" && currentSchool.GetSchoolNumber() != *updatedSchoolNumber { - _, err := i.getSchoolByNumberOrID(*updatedSchoolNumber) + _, err := i.getSchoolByNumber(*updatedSchoolNumber) if err == nil { - errmsg := fmt.Sprintf("school number '%s' already exists", *updatedSchoolNumber) - err = fmt.Errorf(errmsg) - return err + return errSchoolNumberExists } mr.Replace(i.educationConfig.schoolAttributeMap.schoolNumber, []string{*updatedSchoolNumber}) } @@ -255,13 +268,7 @@ func (i *LDAP) updateSchoolProperties(ctx context.Context, dn string, currentSch } if err := i.conn.Modify(mr); err != nil { - var lerr *ldap.Error logger.Debug().Err(err).Msg("error updating school number") - if errors.As(err, &lerr) { - if lerr.ResultCode == ldap.LDAPResultEntryAlreadyExists { - err = errorcode.New(errorcode.NameAlreadyExists, lerr.Error()) - } - } return err } @@ -282,7 +289,7 @@ func (i *LDAP) UpdateEducationSchool(ctx context.Context, numberOrID string, sch } currentSchool := i.createSchoolModelFromLDAP(e) - switch i.UpdateEducationSchoolOperation(school, *currentSchool) { + switch i.updateEducationSchoolOperation(school, *currentSchool) { case tooManyValues: return nil, fmt.Errorf("school name and school number cannot be updated in the same request") case schoolUnchanged: @@ -635,17 +642,12 @@ func (i *LDAP) RemoveClassFromEducationSchool(ctx context.Context, schoolNumberO } func (i *LDAP) getSchoolByDN(dn string) (*ldap.Entry, error) { - attrs := []string{ - i.educationConfig.schoolAttributeMap.displayName, - i.educationConfig.schoolAttributeMap.id, - i.educationConfig.schoolAttributeMap.schoolNumber, - } filter := fmt.Sprintf("(objectClass=%s)", i.educationConfig.schoolObjectClass) if i.educationConfig.schoolFilter != "" { filter = fmt.Sprintf("(&%s(%s))", filter, i.educationConfig.schoolFilter) } - return i.getEntryByDN(dn, attrs, filter) + return i.getEntryByDN(dn, i.getEducationSchoolAttrTypes(), filter) } func (i *LDAP) getSchoolByNumberOrID(numberOrID string) (*ldap.Entry, error) { @@ -660,6 +662,16 @@ func (i *LDAP) getSchoolByNumberOrID(numberOrID string) (*ldap.Entry, error) { return i.getSchoolByFilter(filter) } +func (i *LDAP) getSchoolByNumber(schoolNumber string) (*ldap.Entry, error) { + schoolNumber = ldap.EscapeFilter(schoolNumber) + filter := fmt.Sprintf( + "(%s=%s)", + i.educationConfig.schoolAttributeMap.schoolNumber, + schoolNumber, + ) + return i.getSchoolByFilter(filter) +} + func (i *LDAP) getSchoolByFilter(filter string) (*ldap.Entry, error) { filter = fmt.Sprintf("(&%s(objectClass=%s)%s)", i.educationConfig.schoolFilter, diff --git a/services/graph/pkg/identity/ldap_education_school_test.go b/services/graph/pkg/identity/ldap_education_school_test.go index d0e06353b7..520cdc5c90 100644 --- a/services/graph/pkg/identity/ldap_education_school_test.go +++ b/services/graph/pkg/identity/ldap_education_school_test.go @@ -2,6 +2,7 @@ package identity import ( "context" + "errors" "testing" "time" @@ -9,6 +10,7 @@ import ( libregraph "github.com/owncloud/libre-graph-api-go" "github.com/owncloud/ocis/v2/services/graph/mocks" "github.com/owncloud/ocis/v2/services/graph/pkg/config" + "github.com/owncloud/ocis/v2/services/graph/pkg/service/v0/errorcode" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -67,45 +69,135 @@ var ( ) func TestCreateEducationSchool(t *testing.T) { - lm := &mocks.Client{} - - ldapSchoolAddRequestMatcher := func(ar *ldap.AddRequest) bool { - if ar.DN != "ou=Test School," { - return false - } - for _, attr := range ar.Attributes { - if attr.Type == "ocEducationSchoolTerminationTimestamp" { + tests := []struct { + name string + schoolNumber string + schoolName string + expectedError error + }{ + { + name: "Create a Education School succeeds", + schoolNumber: "0123", + schoolName: "Test School", + expectedError: nil, + }, { + name: "Create a Education School with a duplicated Schoolnumber fails with an error", + schoolNumber: "0666", + schoolName: "Test School", + expectedError: errorcode.New(errorcode.NameAlreadyExists, "A school with that number is already present"), + }, { + name: "Create a Education School with a duplicated Name fails with an error", + schoolNumber: "0123", + schoolName: "Existing Test School", + expectedError: errorcode.New(errorcode.NameAlreadyExists, "A school with that name is already present"), + }, { + name: "Create a Education School fails, when there is a backend error", + schoolNumber: "1111", + schoolName: "Test School", + expectedError: errorcode.New(errorcode.GeneralException, "error looking up school by number"), + }, + } + for _, tt := range tests { + lm := &mocks.Client{} + ldapSchoolGoodAddRequestMatcher := func(ar *ldap.AddRequest) bool { + if ar.DN != "ou=Test School," { return false } + for _, attr := range ar.Attributes { + if attr.Type == "ocEducationSchoolTerminationTimestamp" { + return false + } + } + return true + } + lm.On("Add", mock.MatchedBy(ldapSchoolGoodAddRequestMatcher)).Return(nil) + + ldapExistingSchoolAddRequestMatcher := func(ar *ldap.AddRequest) bool { + if ar.DN == "ou=Existing Test School," { + return true + } + return false + } + lm.On("Add", mock.MatchedBy(ldapExistingSchoolAddRequestMatcher)).Return(ldap.NewError(ldap.LDAPResultEntryAlreadyExists, errors.New(""))) + + schoolNumberSearchRequest := &ldap.SearchRequest{ + BaseDN: "", + Scope: 2, + SizeLimit: 1, + Filter: "(&(objectClass=ocEducationSchool)(ocEducationSchoolNumber=0123))", + Attributes: []string{"ou", "owncloudUUID", "ocEducationSchoolNumber", "ocEducationSchoolTerminationTimestamp"}, + Controls: []ldap.Control(nil), + } + lm.On("Search", schoolNumberSearchRequest). + Return( + &ldap.SearchResult{ + Entries: []*ldap.Entry{}, + }, + nil) + existingSchoolNumberSearchRequest := &ldap.SearchRequest{ + BaseDN: "", + Scope: 2, + SizeLimit: 1, + Filter: "(&(objectClass=ocEducationSchool)(ocEducationSchoolNumber=0666))", + Attributes: []string{"ou", "owncloudUUID", "ocEducationSchoolNumber", "ocEducationSchoolTerminationTimestamp"}, + Controls: []ldap.Control(nil), + } + lm.On("Search", existingSchoolNumberSearchRequest). + Return( + &ldap.SearchResult{ + Entries: []*ldap.Entry{schoolEntry}, + }, + nil) + schoolNumberSearchRequestError := &ldap.SearchRequest{ + BaseDN: "", + Scope: 2, + SizeLimit: 1, + Filter: "(&(objectClass=ocEducationSchool)(ocEducationSchoolNumber=1111))", + Attributes: []string{"ou", "owncloudUUID", "ocEducationSchoolNumber", "ocEducationSchoolTerminationTimestamp"}, + Controls: []ldap.Control(nil), + } + lm.On("Search", schoolNumberSearchRequestError). + Return( + &ldap.SearchResult{ + Entries: []*ldap.Entry{}, + }, + ldap.NewError(ldap.LDAPResultOther, errors.New("some error"))) + schoolLookupAfterCreate := &ldap.SearchRequest{ + BaseDN: "ou=Test School,", + Scope: 0, + SizeLimit: 1, + Filter: "(objectClass=ocEducationSchool)", + Attributes: []string{"ou", "owncloudUUID", "ocEducationSchoolNumber", "ocEducationSchoolTerminationTimestamp"}, + Controls: []ldap.Control(nil), + } + lm.On("Search", schoolLookupAfterCreate). + Return( + &ldap.SearchResult{ + Entries: []*ldap.Entry{schoolEntry}, + }, + nil) + b, err := getMockedBackend(lm, eduConfig, &logger) + assert.Nil(t, err) + assert.NotEqual(t, "", b.educationConfig.schoolObjectClass) + + school := libregraph.NewEducationSchool() + school.SetDisplayName(tt.schoolName) + school.SetSchoolNumber(tt.schoolNumber) + school.SetId("abcd-defg") + res_school, err := b.CreateEducationSchool(context.Background(), *school) + if tt.expectedError == nil { + assert.Nil(t, err) + lm.AssertNumberOfCalls(t, "Add", 1) + assert.NotNil(t, res_school) + assert.Equal(t, res_school.GetDisplayName(), school.GetDisplayName()) + assert.Equal(t, res_school.GetId(), school.GetId()) + assert.Equal(t, res_school.GetSchoolNumber(), school.GetSchoolNumber()) + assert.False(t, res_school.HasTerminationDate()) + } else { + assert.Equal(t, err, tt.expectedError) + assert.Nil(t, res_school) } - return true } - - lm.On("Add", mock.MatchedBy(ldapSchoolAddRequestMatcher)).Return(nil) - - lm.On("Search", mock.Anything). - Return( - &ldap.SearchResult{ - Entries: []*ldap.Entry{schoolEntry}, - }, - nil) - - b, err := getMockedBackend(lm, eduConfig, &logger) - assert.Nil(t, err) - assert.NotEqual(t, "", b.educationConfig.schoolObjectClass) - school := libregraph.NewEducationSchool() - school.SetDisplayName("Test School") - school.SetSchoolNumber("0123") - school.SetId("abcd-defg") - res_school, err := b.CreateEducationSchool(context.Background(), *school) - lm.AssertNumberOfCalls(t, "Add", 1) - lm.AssertNumberOfCalls(t, "Search", 1) - assert.Nil(t, err) - assert.NotNil(t, res_school) - assert.Equal(t, res_school.GetDisplayName(), school.GetDisplayName()) - assert.Equal(t, res_school.GetId(), school.GetId()) - assert.Equal(t, res_school.GetSchoolNumber(), school.GetSchoolNumber()) - assert.False(t, res_school.HasTerminationDate()) } func TestUpdateEducationSchoolTerminationDate(t *testing.T) { @@ -218,7 +310,7 @@ func TestUpdateEducationSchoolOperation(t *testing.T) { SchoolNumber: &tt.schoolNumber, } - operation := b.UpdateEducationSchoolOperation(schoolUpdate, currentSchool) + operation := b.updateEducationSchoolOperation(schoolUpdate, currentSchool) assert.Equal(t, tt.expectedOperation, operation) } } @@ -285,7 +377,7 @@ func TestDeleteEducationSchool(t *testing.T) { if tt.expectedItemNotFound { lm.AssertNumberOfCalls(t, "Del", 0) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } else { assert.Nil(t, err) } @@ -349,7 +441,7 @@ func TestGetEducationSchool(t *testing.T) { if tt.expectedItemNotFound { assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } else { assert.Nil(t, err) assert.Equal(t, "Test School", school.GetDisplayName()) @@ -502,7 +594,7 @@ func TestRemoveMemberFromEducationSchool(t *testing.T) { err = b.RemoveUserFromEducationSchool(context.Background(), "abcd-defg", "does-not-exist") lm.AssertNumberOfCalls(t, "Search", 2) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) err = b.RemoveUserFromEducationSchool(context.Background(), "abcd-defg", "abcd-defg") lm.AssertNumberOfCalls(t, "Search", 4) lm.AssertNumberOfCalls(t, "Modify", 1) @@ -613,7 +705,7 @@ func TestRemoveClassFromEducationSchool(t *testing.T) { err = b.RemoveClassFromEducationSchool(context.Background(), "abcd-defg", "does-not-exist") lm.AssertNumberOfCalls(t, "Search", 2) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) err = b.RemoveClassFromEducationSchool(context.Background(), "abcd-defg", "abcd-defg") lm.AssertNumberOfCalls(t, "Search", 4) lm.AssertNumberOfCalls(t, "Modify", 1) diff --git a/services/graph/pkg/identity/ldap_education_user_test.go b/services/graph/pkg/identity/ldap_education_user_test.go index 8b0e9fff53..763f27b4f3 100644 --- a/services/graph/pkg/identity/ldap_education_user_test.go +++ b/services/graph/pkg/identity/ldap_education_user_test.go @@ -139,7 +139,7 @@ func TestDeleteEducationUser(t *testing.T) { lm.AssertNumberOfCalls(t, "Search", 2) lm.AssertNumberOfCalls(t, "Del", 1) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } func TestGetEducationUser(t *testing.T) { @@ -157,7 +157,7 @@ func TestGetEducationUser(t *testing.T) { _, err = b.GetEducationUser(context.Background(), "xxxx-xxxx") lm.AssertNumberOfCalls(t, "Search", 2) assert.NotNil(t, err) - assert.Equal(t, "itemNotFound", err.Error()) + assert.Equal(t, "itemNotFound: not found", err.Error()) } func TestGetEducationUsers(t *testing.T) { diff --git a/services/graph/pkg/identity/ldap_group_test.go b/services/graph/pkg/identity/ldap_group_test.go index e3613967b8..79431d4760 100644 --- a/services/graph/pkg/identity/ldap_group_test.go +++ b/services/graph/pkg/identity/ldap_group_test.go @@ -59,34 +59,22 @@ func TestGetGroup(t *testing.T) { b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetGroup(context.Background(), "group", nil) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") _, err = b.GetGroup(context.Background(), "group", queryParamExpand) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") _, err = b.GetGroup(context.Background(), "group", queryParamSelect) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") // Mock an empty Search Result lm = &mocks.Client{} lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) _, err = b.GetGroup(context.Background(), "group", nil) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") _, err = b.GetGroup(context.Background(), "group", queryParamExpand) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") _, err = b.GetGroup(context.Background(), "group", queryParamSelect) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") // Mock an invalid Search Result lm = &mocks.Client{} @@ -95,17 +83,11 @@ func TestGetGroup(t *testing.T) { }, nil) b, _ = getMockedBackend(lm, lconfig, &logger) g, err := b.GetGroup(context.Background(), "group", nil) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") g, err = b.GetGroup(context.Background(), "group", queryParamExpand) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") g, err = b.GetGroup(context.Background(), "group", queryParamSelect) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") // Mock a valid Search Result lm = &mocks.Client{} @@ -240,9 +222,7 @@ func TestGetGroups(t *testing.T) { lm.On("Search", mock.Anything).Return(nil, ldap.NewError(ldap.LDAPResultOperationsError, errors.New("mock"))) b, _ := getMockedBackend(lm, lconfig, &logger) _, err := b.GetGroups(context.Background(), url.Values{}) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") lm = &mocks.Client{} lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) diff --git a/services/graph/pkg/identity/ldap_test.go b/services/graph/pkg/identity/ldap_test.go index fcd5a551ae..908862f819 100644 --- a/services/graph/pkg/identity/ldap_test.go +++ b/services/graph/pkg/identity/ldap_test.go @@ -205,14 +205,10 @@ func TestGetUser(t *testing.T) { } _, err = b.GetUser(context.Background(), "fred", odataReqDefault) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") _, err = b.GetUser(context.Background(), "fred", odataReqExpand) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") // Mock an empty Search Result lm = &mocks.Client{} @@ -221,14 +217,10 @@ func TestGetUser(t *testing.T) { &ldap.SearchResult{}, nil) b, _ = getMockedBackend(lm, lconfig, &logger) _, err = b.GetUser(context.Background(), "fred", odataReqDefault) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") _, err = b.GetUser(context.Background(), "fred", odataReqExpand) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") // Mock a valid Search Result lm = &mocks.Client{} @@ -265,9 +257,7 @@ func TestGetUser(t *testing.T) { b, _ = getMockedBackend(lm, lconfig, &logger) _, err = b.GetUser(context.Background(), "invalid", nil) - if err == nil || err.Error() != "itemNotFound" { - t.Errorf("Expected 'itemNotFound' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "itemNotFound:") } func TestGetUsers(t *testing.T) { @@ -283,9 +273,7 @@ func TestGetUsers(t *testing.T) { b, _ := getMockedBackend(lm, lconfig, &logger) _, err = b.GetUsers(context.Background(), odataReqDefault) - if err == nil || err.Error() != "generalException" { - t.Errorf("Expected 'generalException' got '%s'", err.Error()) - } + assert.ErrorContains(t, err, "generalException:") lm = &mocks.Client{} lm.On("Search", mock.Anything).Return(&ldap.SearchResult{}, nil) diff --git a/services/graph/pkg/service/v0/educationschools.go b/services/graph/pkg/service/v0/educationschools.go index eb7b10d7e7..8a176ef205 100644 --- a/services/graph/pkg/service/v0/educationschools.go +++ b/services/graph/pkg/service/v0/educationschools.go @@ -148,7 +148,7 @@ func (g Graph) PatchEducationSchool(w http.ResponseWriter, r *http.Request) { if school, err = g.identityEducationBackend.UpdateEducationSchool(r.Context(), schoolID, *school); err != nil { logger.Debug().Err(err).Interface("school", school).Msg("could not update school: backend error") - errorcode.GeneralException.Render(w, r, http.StatusInternalServerError, err.Error()) + errorcode.RenderError(w, r, err) return } diff --git a/services/graph/pkg/service/v0/errorcode/errorcode.go b/services/graph/pkg/service/v0/errorcode/errorcode.go index e15289f327..f235759c95 100644 --- a/services/graph/pkg/service/v0/errorcode/errorcode.go +++ b/services/graph/pkg/service/v0/errorcode/errorcode.go @@ -1,3 +1,4 @@ +// Package errorcode allows to deal with graph error codes package errorcode import ( @@ -13,6 +14,7 @@ import ( // ErrorCode defines code as used in MS Graph - see https://docs.microsoft.com/en-us/graph/errors?context=graph%2Fapi%2F1.0&view=graph-rest-1.0 type ErrorCode int +// Error defines a custom error struct, containing and MS Graph error code an a textual error message type Error struct { errorCode ErrorCode msg string @@ -75,6 +77,7 @@ var errorCodes = [...]string{ "preconditionFailed", } +// New constructs a new errorcode.Error func New(e ErrorCode, msg string) Error { return Error{ errorCode: e, @@ -82,7 +85,7 @@ func New(e ErrorCode, msg string) Error { } } -// Render writes an Graph ErrorObject to the response writer +// Render writes an Graph ErrorCode object to the response writer func (e ErrorCode) Render(w http.ResponseWriter, r *http.Request, status int, msg string) { innererror := map[string]interface{}{ "date": time.Now().UTC().Format(time.RFC3339), @@ -100,6 +103,7 @@ func (e ErrorCode) Render(w http.ResponseWriter, r *http.Request, status int, ms render.JSON(w, r, resp) } +// Render writes an Graph Error object to the response writer func (e Error) Render(w http.ResponseWriter, r *http.Request) { var status int switch e.errorCode { @@ -115,12 +119,18 @@ func (e Error) Render(w http.ResponseWriter, r *http.Request) { e.errorCode.Render(w, r, status, e.msg) } +// String returns the string corresponding to the ErrorCode func (e ErrorCode) String() string { return errorCodes[e] } +// Error return the concatenation of the error string and optinal message func (e Error) Error() string { - return errorCodes[e.errorCode] + errString := errorCodes[e.errorCode] + if e.msg != "" { + errString += ": " + e.msg + } + return errString } // RenderError render the Graph Error based on a code or default one