From d8c538f7205ae8d37e33c2992f4342da68626da1 Mon Sep 17 00:00:00 2001 From: David Christofas Date: Wed, 8 Jul 2020 19:14:32 +0200 Subject: [PATCH] refactor to make the code more readable Signed-off-by: David Christofas --- pkg/server/glauth/handler.go | 206 +++++++++++++++++++++++++---------- 1 file changed, 150 insertions(+), 56 deletions(-) diff --git a/pkg/server/glauth/handler.go b/pkg/server/glauth/handler.go index b06d1da745..1f38021ccc 100644 --- a/pkg/server/glauth/handler.go +++ b/pkg/server/glauth/handler.go @@ -17,6 +17,13 @@ import ( "github.com/owncloud/ocis-pkg/v2/log" ) +type queryType string + +const ( + usersQuery queryType = "users" + groupsQuery queryType = "groups" +) + type ocisHandler struct { as accounts.AccountsService log log.Logger @@ -27,18 +34,30 @@ func (h ocisHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAP bindDN = strings.ToLower(bindDN) baseDN := strings.ToLower("," + h.cfg.Backend.BaseDN) - h.log.Debug().Str("binddn", bindDN).Str("basedn", h.cfg.Backend.BaseDN).Interface("src", conn.RemoteAddr()).Msg("Bind request") + h.log.Debug(). + Str("binddn", bindDN). + Str("basedn", h.cfg.Backend.BaseDN). + Interface("src", conn.RemoteAddr()). + Msg("Bind request") stats.Frontend.Add("bind_reqs", 1) // parse the bindDN - ensure that the bindDN ends with the BaseDN if !strings.HasSuffix(bindDN, baseDN) { - h.log.Error().Str("binddn", bindDN).Str("basedn", h.cfg.Backend.BaseDN).Interface("src", conn.RemoteAddr()).Msg("BindDN not part of our BaseDN") + h.log.Error(). + Str("binddn", bindDN). + Str("basedn", h.cfg.Backend.BaseDN). + Interface("src", conn.RemoteAddr()). + Msg("BindDN not part of our BaseDN") return ldap.LDAPResultInvalidCredentials, nil } parts := strings.Split(strings.TrimSuffix(bindDN, baseDN), ",") if len(parts) > 2 { - h.log.Error().Str("binddn", bindDN).Int("numparts", len(parts)).Interface("src", conn.RemoteAddr()).Msg("BindDN should have only one or two parts") + h.log.Error(). + Str("binddn", bindDN). + Int("numparts", len(parts)). + Interface("src", conn.RemoteAddr()). + Msg("BindDN should have only one or two parts") return ldap.LDAPResultInvalidCredentials, nil } userName := strings.TrimPrefix(parts[0], "cn=") @@ -52,12 +71,19 @@ func (h ocisHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAP Query: fmt.Sprintf("login eq '%s' and password eq '%s'", userName, bindSimplePw), }) if err != nil { - h.log.Error().Str("username", userName).Str("binddn", bindDN).Interface("src", conn.RemoteAddr()).Msg("Login failed") + h.log.Error(). + Str("username", userName). + Str("binddn", bindDN). + Interface("src", conn.RemoteAddr()). + Msg("Login failed") return ldap.LDAPResultInvalidCredentials, nil } stats.Frontend.Add("bind_successes", 1) - h.log.Debug().Str("binddn", bindDN).Interface("src", conn.RemoteAddr()).Msg("Bind success") + h.log.Debug(). + Str("binddn", bindDN). + Interface("src", conn.RemoteAddr()). + Msg("Bind success") return ldap.LDAPResultSuccess, nil } @@ -65,21 +91,32 @@ func (h ocisHandler) Search(bindDN string, searchReq ldap.SearchRequest, conn ne bindDN = strings.ToLower(bindDN) baseDN := strings.ToLower("," + h.cfg.Backend.BaseDN) searchBaseDN := strings.ToLower(searchReq.BaseDN) - h.log.Debug().Str("binddn", bindDN).Str("basedn", h.cfg.Backend.BaseDN).Str("filter", searchReq.Filter).Interface("src", conn.RemoteAddr()).Msg("Search request") + h.log.Debug(). + Str("binddn", bindDN). + Str("basedn", h.cfg.Backend.BaseDN). + Str("filter", searchReq.Filter). + Interface("src", conn.RemoteAddr()). + Msg("Search request") stats.Frontend.Add("search_reqs", 1) // validate the user is authenticated and has appropriate access if len(bindDN) < 1 { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("search error: Anonymous BindDN not allowed %s", bindDN) + return ldap.ServerSearchResult{ + ResultCode: ldap.LDAPResultInsufficientAccessRights, + }, fmt.Errorf("search error: Anonymous BindDN not allowed %s", bindDN) } if !strings.HasSuffix(bindDN, baseDN) { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("search error: BindDN %s not in our BaseDN %s", bindDN, h.cfg.Backend.BaseDN) + return ldap.ServerSearchResult{ + ResultCode: ldap.LDAPResultInsufficientAccessRights, + }, fmt.Errorf("search error: BindDN %s not in our BaseDN %s", bindDN, h.cfg.Backend.BaseDN) } if !strings.HasSuffix(searchBaseDN, h.cfg.Backend.BaseDN) { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("search error: search BaseDN %s is not in our BaseDN %s", searchBaseDN, h.cfg.Backend.BaseDN) + return ldap.ServerSearchResult{ + ResultCode: ldap.LDAPResultInsufficientAccessRights, + }, fmt.Errorf("search error: search BaseDN %s is not in our BaseDN %s", searchBaseDN, h.cfg.Backend.BaseDN) } - qtype := "" + var qtype queryType = "" query := "" var err error if searchReq.Filter == "(&)" { // see Absolute True and False Filters in https://tools.ietf.org/html/rfc4526#section-2 @@ -88,55 +125,106 @@ func (h ocisHandler) Search(bindDN string, searchReq ldap.SearchRequest, conn ne var cf *ber.Packet cf, err = ldap.CompileFilter(searchReq.Filter) if err != nil { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", searchReq.Filter) + return ldap.ServerSearchResult{ + ResultCode: ldap.LDAPResultOperationsError, + }, fmt.Errorf("Search Error: error parsing filter: %s", searchReq.Filter) } qtype, query, err = parseFilter(cf) if err != nil { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", searchReq.Filter) + return ldap.ServerSearchResult{ + ResultCode: ldap.LDAPResultOperationsError, + }, fmt.Errorf("Search Error: error parsing filter: %s", searchReq.Filter) } } entries := []*ldap.Entry{} - h.log.Debug().Str("binddn", bindDN).Str("basedn", h.cfg.Backend.BaseDN).Str("filter", searchReq.Filter).Str("qtype", qtype).Str("query", query).Msg("parsed query") - if qtype == "users" { + h.log.Debug(). + Str("binddn", bindDN). + Str("basedn", h.cfg.Backend.BaseDN). + Str("filter", searchReq.Filter). + Str("qtype", string(qtype)). + Str("query", query). + Msg("parsed query") + switch qtype { + case usersQuery: accounts, err := h.as.ListAccounts(context.TODO(), &accounts.ListAccountsRequest{ Query: query, }) if err != nil { - h.log.Error().Err(err).Str("binddn", bindDN).Str("basedn", h.cfg.Backend.BaseDN).Str("filter", searchReq.Filter).Str("query", query).Interface("src", conn.RemoteAddr()).Msg("Could not list accounts") - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("search error: error getting users") - } - for i := range accounts.Accounts { - attrs := []*ldap.EntryAttribute{ - {Name: "objectClass", Values: []string{"posixAccount", "inetOrgPerson", "organizationalPerson", "Person", "top"}}, - {Name: "cn", Values: []string{accounts.Accounts[i].PreferredName}}, - {Name: "uid", Values: []string{accounts.Accounts[i].PreferredName}}, - {Name: "sn", Values: []string{accounts.Accounts[i].PreferredName}}, // must be set for a valid person - } - if accounts.Accounts[i].DisplayName != "" { - attrs = append(attrs, &ldap.EntryAttribute{Name: "displayName", Values: []string{accounts.Accounts[i].DisplayName}}) - } - if accounts.Accounts[i].Mail != "" { - attrs = append(attrs, &ldap.EntryAttribute{Name: "mail", Values: []string{accounts.Accounts[i].Mail}}) - } - if accounts.Accounts[i].UidNumber != 0 { // TODO no root? - attrs = append(attrs, &ldap.EntryAttribute{Name: "uidnumber", Values: []string{strconv.FormatInt(accounts.Accounts[i].UidNumber, 10)}}) - } - if accounts.Accounts[i].GidNumber != 0 { - attrs = append(attrs, &ldap.EntryAttribute{Name: "gidnumber", Values: []string{strconv.FormatInt(accounts.Accounts[i].GidNumber, 10)}}) - } - if accounts.Accounts[i].Description != "" { - attrs = append(attrs, &ldap.EntryAttribute{Name: "description", Values: []string{accounts.Accounts[i].Description}}) - } + h.log.Error(). + Err(err). + Str("binddn", bindDN). + Str("basedn", h.cfg.Backend.BaseDN). + Str("filter", searchReq.Filter). + Str("query", query). + Interface("src", conn.RemoteAddr()). + Msg("Could not list accounts") - dn := fmt.Sprintf("%s=%s,%s=%s,%s", h.cfg.Backend.NameFormat, accounts.Accounts[i].PreferredName, h.cfg.Backend.GroupFormat, "users", h.cfg.Backend.BaseDN) - entries = append(entries, &ldap.Entry{DN: dn, Attributes: attrs}) + return ldap.ServerSearchResult{ + ResultCode: ldap.LDAPResultOperationsError, + }, errors.New("search error: error getting users") } + entries = append(entries, h.mapAccounts(accounts.Accounts)...) } stats.Frontend.Add("search_successes", 1) - h.log.Debug().Str("binddn", bindDN).Str("basedn", h.cfg.Backend.BaseDN).Str("filter", searchReq.Filter).Interface("src", conn.RemoteAddr()).Msg("AP: Search OK") - return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil + h.log.Debug(). + Str("binddn", bindDN). + Str("basedn", h.cfg.Backend.BaseDN). + Str("filter", searchReq.Filter). + Interface("src", conn.RemoteAddr()). + Msg("AP: Search OK") + + return ldap.ServerSearchResult{ + Entries: entries, + Referrals: []string{}, + Controls: []ldap.Control{}, + ResultCode: ldap.LDAPResultSuccess, + }, nil +} + +func attribute(name string, values ...string) *ldap.EntryAttribute { + return &ldap.EntryAttribute{ + Name: name, + Values: values, + } +} + +func (h ocisHandler) mapAccounts(accounts []*accounts.Account) []*ldap.Entry { + var entries []*ldap.Entry + for _, acc := range accounts { + attrs := []*ldap.EntryAttribute{ + attribute("objectClass", "posixAccount", "inetOrgPerson", "organizationalPerson", "Person", "top"), + attribute("cn", acc.PreferredName), + attribute("uid", acc.PreferredName), + attribute("sn", acc.PreferredName), + } + if acc.DisplayName != "" { + attrs = append(attrs, attribute("displayName", acc.DisplayName)) + } + if acc.Mail != "" { + attrs = append(attrs, attribute("mail", acc.Mail)) + } + if acc.UidNumber != 0 { // TODO no root? + attrs = append(attrs, attribute("uidnumber", strconv.FormatInt(acc.UidNumber, 10))) + } + if acc.GidNumber != 0 { + attrs = append(attrs, attribute("gidnumber", strconv.FormatInt(acc.GidNumber, 10))) + } + if acc.Description != "" { + attrs = append(attrs, attribute("description", acc.Description)) + } + + dn := fmt.Sprintf("%s=%s,%s=%s,%s", + h.cfg.Backend.NameFormat, + acc.PreferredName, + h.cfg.Backend.GroupFormat, + "users", + h.cfg.Backend.BaseDN, + ) + entries = append(entries, &ldap.Entry{DN: dn, Attributes: attrs}) + } + return entries } // LDAP filters might ask for grouips and users at the same time, eg. @@ -150,7 +238,7 @@ func (h ocisHandler) Search(bindDN string, searchReq ldap.SearchRequest, conn ne // "" not determined // "users" // "groups" -func parseFilter(f *ber.Packet) (qtype string, q string, err error) { +func parseFilter(f *ber.Packet) (qtype queryType, q string, err error) { switch ldap.FilterMap[f.Tag] { case "Equality Match": if len(f.Children) != 2 { @@ -164,33 +252,34 @@ func parseFilter(f *ber.Packet) (qtype string, q string, err error) { case "objectclass": switch value { case "posixaccount", "shadowaccount", "users", "person", "inetorgperson", "organizationalperson": - qtype = "users" + qtype = usersQuery case "posixgroup", "groups": - qtype = "groups" + qtype = groupsQuery default: qtype = "" } - return qtype, "", nil case "cn", "uid": - return "", fmt.Sprintf("preferred_name eq '%s'", strings.ReplaceAll(value, "'", "''")), nil + q = fmt.Sprintf("preferred_name eq '%s'", escapeValue(value)) case "mail": - return "", fmt.Sprintf("mail eq '%s'", strings.ReplaceAll(value, "'", "''")), nil + q = fmt.Sprintf("mail eq '%s'", escapeValue(value)) case "displayname": - return "", fmt.Sprintf("display_name eq '%s'", strings.ReplaceAll(value, "'", "''")), nil + q = fmt.Sprintf("display_name eq '%s'", escapeValue(value)) case "uidnumber": - return "", fmt.Sprintf("uid_number eq '%s'", strings.ReplaceAll(value, "'", "''")), nil + q = fmt.Sprintf("uid_number eq '%s'", escapeValue(value)) case "gidnumber": - return "", fmt.Sprintf("gid_number eq '%s'", strings.ReplaceAll(value, "'", "''")), nil + q = fmt.Sprintf("gid_number eq '%s'", escapeValue(value)) case "description": - return "", fmt.Sprintf("description eq '%s'", strings.ReplaceAll(value, "'", "''")), nil + q = fmt.Sprintf("description eq '%s'", escapeValue(value)) + default: + err = fmt.Errorf("filter by %s not implmented", attribute) } - return "", "", fmt.Errorf("filter by %s not implmented", attribute) + return case "And": subQueries := []string{} for _, child := range f.Children { var subQuery string - var qt string + var qt queryType qt, subQuery, err = parseFilter(child) if err != nil { return "", "", err @@ -209,7 +298,7 @@ func parseFilter(f *ber.Packet) (qtype string, q string, err error) { subQueries := []string{} for _, child := range f.Children { var subQuery string - var qt string + var qt queryType qt, subQuery, err = parseFilter(child) if err != nil { return "", "", err @@ -240,6 +329,11 @@ func parseFilter(f *ber.Packet) (qtype string, q string, err error) { return } +// escapeValue escapes all special characters in the value +func escapeValue(value string) string { + return strings.ReplaceAll(value, "'", "''") +} + func (h ocisHandler) Close(boundDN string, conn net.Conn) error { stats.Frontend.Add("closes", 1) return nil