chore(BRIDGE-534): replace exp/maps and exp/slice with common utils & stdlib packages

This commit is contained in:
Sebastijan Zindl
2026-04-24 11:08:19 +02:00
parent 24794d804a
commit 6bf7f5a61e
12 changed files with 462 additions and 25 deletions

2
go.mod
View File

@@ -22,7 +22,6 @@ require (
github.com/urfave/cli/v2 v2.27.7
gitlab.com/c0b/go-ordered-json v0.0.0-20201030195603-febf46534d5a
go.uber.org/goleak v1.3.0
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90
golang.org/x/net v0.52.0
golang.org/x/text v0.36.0
google.golang.org/grpc v1.80.0
@@ -67,6 +66,7 @@ require (
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect

View File

@@ -10,6 +10,7 @@ import (
"github.com/ProtonMail/go-crypto/openpgp"
"github.com/ProtonMail/go-crypto/openpgp/armor"
"github.com/ProtonMail/go-proton-api/pkg/utils"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
)
@@ -148,7 +149,7 @@ func (keys Keys) Unlock(passphrase []byte, userKR *crypto.KeyRing) (*crypto.KeyR
return nil, err
}
for _, key := range Filter(keys, func(key Key) bool { return bool(key.Active) }) {
for _, key := range utils.Filter(keys, func(key Key) bool { return bool(key.Active) }) {
unlocked, err := key.Unlock(passphrase, userKR)
if err != nil {
log.WithField("KeyID", key.ID).WithError(err).Warning("Cannot unlock key")

28
pkg/utils/maps.go Normal file
View File

@@ -0,0 +1,28 @@
package utils
import (
"maps"
"slices"
)
// Keys returns a slice of keys from the map.
// Alternative to using maps.Keys which returns an iterator instead of a slice.
func Keys[M ~map[K]V, K comparable, V any](m M) []K {
keys := make([]K, 0, len(m))
return slices.AppendSeq(
keys,
maps.Keys(m),
)
}
// Values returns a slice of values from the map.
// Alternative to using maps.Values which returns an iterator instead of a slice.
func Values[M ~map[K]V, K comparable, V any](m M) []V {
values := make([]V, 0, len(m))
return slices.AppendSeq(
values,
maps.Values(m),
)
}

241
pkg/utils/maps_test.go Normal file
View File

@@ -0,0 +1,241 @@
package utils
import (
"testing"
"github.com/stretchr/testify/require"
)
type keysTestCase[K comparable, V any] struct {
name string
expected []K
builder func() map[K]V
}
type valuesTestCase[K comparable, V any] struct {
name string
expected []V
builder func() map[K]V
}
func TestMaps_Keys_String(t *testing.T) {
testCases := []keysTestCase[string, int]{
{
name: "empty",
builder: func() map[string]int {
return make(map[string]int, 0)
},
expected: []string{},
},
{
name: "valid",
builder: func() map[string]int {
return map[string]int{
"a": 1,
"b": 2,
"c": 3,
}
},
expected: []string{"a", "b", "c"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provided := tc.builder()
result := Keys(provided)
require.Len(t, result, len(tc.expected))
for _, v := range result {
require.Contains(t, tc.expected, v)
}
})
}
}
func TestMaps_Keys_Int(t *testing.T) {
testCases := []keysTestCase[int, int]{
{
name: "empty",
builder: func() map[int]int {
return make(map[int]int, 0)
},
expected: []int{},
},
{
name: "valid",
builder: func() map[int]int {
return map[int]int{
1: 2,
3: 4,
5: 6,
}
},
expected: []int{1, 3, 5},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provided := tc.builder()
result := Keys(provided)
require.Len(t, result, len(tc.expected))
for _, v := range result {
require.Contains(t, tc.expected, v)
}
})
}
}
func TestMaps_Keys_Struct(t *testing.T) {
type customStruct struct {
ID string
Name string
}
testCases := []keysTestCase[customStruct, int]{
{
name: "empty",
builder: func() map[customStruct]int {
return make(map[customStruct]int, 0)
},
expected: []customStruct{},
},
{
name: "valid",
builder: func() map[customStruct]int {
return map[customStruct]int{
{ID: "1", Name: "test"}: 1,
}
},
expected: []customStruct{{ID: "1", Name: "test"}},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provided := tc.builder()
result := Keys(provided)
require.Len(t, result, len(tc.expected))
for _, v := range result {
require.Contains(t, tc.expected, v)
}
})
}
}
func TestMaps_Values_String(t *testing.T) {
testCases := []valuesTestCase[int, string]{
{
name: "empty",
builder: func() map[int]string {
return make(map[int]string, 0)
},
expected: []string{},
},
{
name: "valid",
builder: func() map[int]string {
return map[int]string{
1: "a",
3: "b",
5: "c",
}
},
expected: []string{"a", "b", "c"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provided := tc.builder()
result := Values(provided)
require.Len(t, result, len(tc.expected))
for _, v := range result {
require.Contains(t, tc.expected, v)
}
})
}
}
func TestMaps_Values_Int(t *testing.T) {
testCases := []valuesTestCase[int, int]{
{
name: "empty",
builder: func() map[int]int {
return make(map[int]int, 0)
},
expected: []int{},
},
{
name: "valid",
builder: func() map[int]int {
return map[int]int{
1: 2,
3: 4,
5: 6,
}
},
expected: []int{2, 4, 6},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provided := tc.builder()
result := Values(provided)
require.Len(t, result, len(tc.expected))
for _, v := range result {
require.Contains(t, tc.expected, v)
}
})
}
}
func TestMaps_Values_Struct(t *testing.T) {
type customStruct struct {
ID string
Name string
}
testCases := []valuesTestCase[int, customStruct]{
{
name: "empty",
builder: func() map[int]customStruct {
return make(map[int]customStruct, 0)
},
expected: []customStruct{},
},
{
name: "valid",
builder: func() map[int]customStruct {
return map[int]customStruct{
1: {ID: "1", Name: "test"},
}
},
expected: []customStruct{{ID: "1", Name: "test"}},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
provided := tc.builder()
result := Values(provided)
require.Len(t, result, len(tc.expected))
for _, v := range result {
require.Contains(t, tc.expected, v)
}
})
}
}

View File

@@ -1,4 +1,4 @@
package proton
package utils
import "slices"

165
pkg/utils/slices_test.go Normal file
View File

@@ -0,0 +1,165 @@
package utils
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
type filterTestCase[T any] struct {
name string
input []T
keep func(T) bool
expected []T
}
func TestSlice_Filter_Int(t *testing.T) {
testCases := []filterTestCase[int]{
{
name: "empty",
input: []int{},
keep: func(_ int) bool {
return true
},
expected: []int{},
},
{
name: "all",
input: []int{1, 2, 3},
keep: func(_ int) bool {
return true
},
expected: []int{1, 2, 3},
},
{
name: "none",
input: []int{1, 2, 3},
keep: func(_ int) bool {
return false
},
expected: []int{},
},
{
name: "only one",
input: []int{1, 2, 3},
keep: func(i int) bool {
return i == 2
},
expected: []int{2},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := Filter(tc.input, tc.keep)
require.Equal(t, tc.expected, result)
require.Len(t, result, len(tc.expected))
})
}
}
func TestSlice_Filter_String(t *testing.T) {
testCases := []filterTestCase[string]{
{
name: "empty",
input: []string{},
keep: func(_ string) bool {
return true
},
expected: []string{},
},
{
name: "all",
input: []string{"a", "b", "c"},
keep: func(_ string) bool {
return true
},
expected: []string{"a", "b", "c"},
},
{
name: "none",
input: []string{"a", "b", "c"},
keep: func(_ string) bool {
return false
},
expected: []string{},
},
{
name: "only one",
input: []string{"a", "b", "c"},
keep: func(s string) bool {
return s == "b"
},
expected: []string{"b"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := Filter(tc.input, tc.keep)
require.Equal(t, tc.expected, result)
require.Len(t, result, len(tc.expected))
})
}
}
func TestSlice_Filter_Struct(t *testing.T) {
type testStruct struct {
ID int
Name string
}
newRandomTestStruct := func(id int) testStruct {
return testStruct{
ID: id,
Name: fmt.Sprintf("test-%d", id),
}
}
testCases := []filterTestCase[testStruct]{
{
name: "empty",
input: []testStruct{},
keep: func(_ testStruct) bool {
return true
},
expected: []testStruct{},
},
{
name: "all",
input: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)},
keep: func(_ testStruct) bool {
return true
},
expected: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)},
},
{
name: "none",
input: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)},
keep: func(_ testStruct) bool {
return false
},
expected: []testStruct{},
},
{
name: "specific id",
input: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)},
keep: func(ts testStruct) bool {
return ts.ID == 2
},
expected: []testStruct{newRandomTestStruct(2)},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := Filter(tc.input, tc.keep)
require.Equal(t, tc.expected, result)
require.Len(t, result, len(tc.expected))
})
}
}

View File

@@ -2,10 +2,10 @@ package server
import (
"net/http"
"slices"
"github.com/ProtonMail/go-proton-api"
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
)
func (s *Server) handleGetAddresses() gin.HandlerFunc {

View File

@@ -12,9 +12,9 @@ import (
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/pkg/utils"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/maps"
)
func (b *Backend) GetUser(userID string) (proton.User, error) {
@@ -183,7 +183,7 @@ func (b *Backend) GetAddress(userID, addrID string) (proton.Address, error) {
func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) {
return readBackendRetErr(b, func(b *unsafeBackend) ([]proton.Address, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Address, error) {
return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address {
return xslices.Map(utils.Values(acc.addresses), func(add *address) proton.Address {
return add.toAddress()
}), nil
})
@@ -342,7 +342,7 @@ func (b *Backend) GetLabels(userID string, types ...proton.LabelType) ([]proton.
}
if len(types) > 0 {
res = proton.Filter(res, func(label proton.Label) bool {
res = utils.Filter(res, func(label proton.Label) bool {
return slices.Contains(types, label.Type)
})
}
@@ -432,7 +432,7 @@ func (b *Backend) DeleteLabel(userID, labelID string) error {
return err
}
acc.labelIDs = proton.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID })
acc.labelIDs = utils.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID })
acc.updateIDs = append(acc.updateIDs, updateID)
}
@@ -506,7 +506,7 @@ func (b *Backend) GetMessages(userID string, page, pageSize int, filter proton.M
}
}
metadata = proton.Filter(metadata, func(metadata proton.MessageMetadata) bool {
metadata = utils.Filter(metadata, func(metadata proton.MessageMetadata) bool {
if len(filter.ID) > 0 {
if !slices.Contains(filter.ID, metadata.ID) {
return false
@@ -693,7 +693,7 @@ func (b *Backend) DeleteMessage(userID, messageID string) error {
}
for _, attID := range message.attIDs {
if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool {
if xslices.CountFunc(utils.Values(b.attachments), func(att *attachment) bool {
return att.attDataID == b.attachments[attID].attDataID
}) == 1 {
delete(b.attData, b.attachments[attID].attDataID)
@@ -709,7 +709,7 @@ func (b *Backend) DeleteMessage(userID, messageID string) error {
return err
}
acc.messageIDs = proton.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID })
acc.messageIDs = utils.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID })
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
@@ -1235,7 +1235,7 @@ func (b *Backend) GetUserContacts(userID string, page int, pageSize int) (int, [
contacts, err := readBackendRetErr(b, func(b *unsafeBackend) ([]proton.Contact, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
total = len(acc.contacts)
values := maps.Values(acc.contacts)
values := utils.Values(acc.contacts)
slices.SortFunc(values, func(i, j *proton.Contact) int {
return strings.Compare(i.ID, j.ID)
})

View File

@@ -3,18 +3,18 @@ package backend
import (
"fmt"
"net/mail"
"slices"
"sync"
"time"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/pkg/utils"
"github.com/ProtonMail/go-srp"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
var log = logrus.WithField("pkg", "gpa/server/backend")
@@ -168,7 +168,7 @@ func (b *Backend) RemoveUser(userID string) error {
for _, messageID := range user.messageIDs {
for _, attID := range b.messages[messageID].attIDs {
if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool {
if xslices.CountFunc(utils.Values(b.attachments), func(att *attachment) bool {
return att.attDataID == b.attachments[attID].attDataID
}) == 1 {
delete(b.attData, b.attachments[attID].attDataID)

View File

@@ -2,14 +2,15 @@ package backend
import (
"net/mail"
"slices"
"strings"
"time"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/pkg/utils"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"golang.org/x/exp/slices"
)
type message struct {
@@ -320,7 +321,7 @@ func (msg *message) addLabel(labelID string, labels map[string]*label) {
}
func (msg *message) addFlagLabel(labelID string, labels map[string]*label) {
msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool {
msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool {
return labels[otherLabelID].labelType == proton.LabelTypeLabel
})
@@ -328,7 +329,7 @@ func (msg *message) addFlagLabel(labelID string, labels map[string]*label) {
}
func (msg *message) addSystemLabel(labelID string, labels map[string]*label) {
msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool {
msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool {
return labels[otherLabelID].labelType == proton.LabelTypeLabel
})
@@ -337,7 +338,7 @@ func (msg *message) addSystemLabel(labelID string, labels map[string]*label) {
func (msg *message) addUserLabel(label *label, labels map[string]*label) {
if label.labelType != proton.LabelTypeLabel {
msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool {
msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool {
return labels[otherLabelID].labelType == proton.LabelTypeLabel
})
@@ -380,7 +381,7 @@ func (msg *message) remSystemLabel(labelID string, labels map[string]*label) {
}
func (msg *message) remUserLabel(label *label, labels map[string]*label) {
msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool {
msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool {
return otherLabelID != label.labelID
})
}

View File

@@ -7,6 +7,7 @@ import (
"mime"
"net/http"
"net/mail"
"slices"
"strconv"
"time"
@@ -15,7 +16,6 @@ import (
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
)
const (
@@ -489,14 +489,14 @@ func (s *Server) parseMessage(literal []byte) (*rfc822.Header, []string, []*rfc8
mimeType = "multipart/mixed"
children, err := root.Children()
// or determine it if there is only one (non-attachment) child
if err == nil && (len(children) - len(atts)) <= 1 {
if err == nil && (len(children)-len(atts)) <= 1 {
var isHtml = false
var isTxt = false
for _, child := range children {
contentType, _, err := child.ContentType()
if err != nil {
continue
}else if contentType == rfc822.TextHTML {
} else if contentType == rfc822.TextHTML {
isHtml = true
} else if contentType == rfc822.TextPlain {
isTxt = true

View File

@@ -11,6 +11,7 @@ import (
"net/url"
"os"
"runtime"
"slices"
"sync"
"sync/atomic"
"testing"
@@ -20,6 +21,7 @@ import (
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/pkg/utils"
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/iterator"
@@ -30,7 +32,6 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
func TestServer_LoginLogout(t *testing.T) {
@@ -2216,7 +2217,7 @@ func TestServer_GetMessageGroupCount(t *testing.T) {
counts, err := c.GetGroupedMessageCount(ctx)
require.NoError(t, err)
counts = proton.Filter(counts, func(t proton.MessageGroupCount) bool {
counts = utils.Filter(counts, func(t proton.MessageGroupCount) bool {
switch t.LabelID {
case proton.InboxLabel, proton.TrashLabel, proton.ArchiveLabel, proton.AllMailLabel, proton.SentLabel:
return true