diff --git a/go.mod b/go.mod index f439957..cc57996 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,6 @@ require ( github.com/urfave/cli/v2 v2.27.7 gitlab.com/c0b/go-ordered-json v0.0.0-20201030195603-febf46534d5a go.uber.org/goleak v1.3.0 - golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 golang.org/x/net v0.52.0 golang.org/x/text v0.36.0 google.golang.org/grpc v1.80.0 @@ -67,6 +66,7 @@ require ( go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/crypto v0.49.0 // indirect + golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect diff --git a/keyring.go b/keyring.go index 665c6fd..e567e54 100644 --- a/keyring.go +++ b/keyring.go @@ -10,6 +10,7 @@ import ( "github.com/ProtonMail/go-crypto/openpgp" "github.com/ProtonMail/go-crypto/openpgp/armor" + "github.com/ProtonMail/go-proton-api/pkg/utils" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" ) @@ -148,7 +149,7 @@ func (keys Keys) Unlock(passphrase []byte, userKR *crypto.KeyRing) (*crypto.KeyR return nil, err } - for _, key := range Filter(keys, func(key Key) bool { return bool(key.Active) }) { + for _, key := range utils.Filter(keys, func(key Key) bool { return bool(key.Active) }) { unlocked, err := key.Unlock(passphrase, userKR) if err != nil { log.WithField("KeyID", key.ID).WithError(err).Warning("Cannot unlock key") diff --git a/pkg/utils/maps.go b/pkg/utils/maps.go new file mode 100644 index 0000000..6b1afb9 --- /dev/null +++ b/pkg/utils/maps.go @@ -0,0 +1,28 @@ +package utils + +import ( + "maps" + "slices" +) + +// Keys returns a slice of keys from the map. +// Alternative to using maps.Keys which returns an iterator instead of a slice. +func Keys[M ~map[K]V, K comparable, V any](m M) []K { + keys := make([]K, 0, len(m)) + + return slices.AppendSeq( + keys, + maps.Keys(m), + ) +} + +// Values returns a slice of values from the map. +// Alternative to using maps.Values which returns an iterator instead of a slice. +func Values[M ~map[K]V, K comparable, V any](m M) []V { + values := make([]V, 0, len(m)) + + return slices.AppendSeq( + values, + maps.Values(m), + ) +} diff --git a/pkg/utils/maps_test.go b/pkg/utils/maps_test.go new file mode 100644 index 0000000..1a68042 --- /dev/null +++ b/pkg/utils/maps_test.go @@ -0,0 +1,241 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type keysTestCase[K comparable, V any] struct { + name string + expected []K + builder func() map[K]V +} + +type valuesTestCase[K comparable, V any] struct { + name string + expected []V + builder func() map[K]V +} + +func TestMaps_Keys_String(t *testing.T) { + testCases := []keysTestCase[string, int]{ + { + name: "empty", + builder: func() map[string]int { + return make(map[string]int, 0) + }, + expected: []string{}, + }, + { + name: "valid", + builder: func() map[string]int { + return map[string]int{ + "a": 1, + "b": 2, + "c": 3, + } + }, + expected: []string{"a", "b", "c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provided := tc.builder() + + result := Keys(provided) + require.Len(t, result, len(tc.expected)) + + for _, v := range result { + require.Contains(t, tc.expected, v) + } + }) + } +} + +func TestMaps_Keys_Int(t *testing.T) { + testCases := []keysTestCase[int, int]{ + { + name: "empty", + builder: func() map[int]int { + return make(map[int]int, 0) + }, + expected: []int{}, + }, + { + name: "valid", + builder: func() map[int]int { + return map[int]int{ + 1: 2, + 3: 4, + 5: 6, + } + }, + expected: []int{1, 3, 5}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provided := tc.builder() + + result := Keys(provided) + require.Len(t, result, len(tc.expected)) + + for _, v := range result { + require.Contains(t, tc.expected, v) + } + }) + } +} + +func TestMaps_Keys_Struct(t *testing.T) { + type customStruct struct { + ID string + Name string + } + + testCases := []keysTestCase[customStruct, int]{ + { + name: "empty", + builder: func() map[customStruct]int { + return make(map[customStruct]int, 0) + }, + expected: []customStruct{}, + }, + { + name: "valid", + builder: func() map[customStruct]int { + return map[customStruct]int{ + {ID: "1", Name: "test"}: 1, + } + }, + expected: []customStruct{{ID: "1", Name: "test"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provided := tc.builder() + + result := Keys(provided) + require.Len(t, result, len(tc.expected)) + + for _, v := range result { + require.Contains(t, tc.expected, v) + } + }) + } +} + +func TestMaps_Values_String(t *testing.T) { + testCases := []valuesTestCase[int, string]{ + { + name: "empty", + builder: func() map[int]string { + return make(map[int]string, 0) + }, + expected: []string{}, + }, + { + name: "valid", + builder: func() map[int]string { + return map[int]string{ + 1: "a", + 3: "b", + 5: "c", + } + }, + expected: []string{"a", "b", "c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provided := tc.builder() + + result := Values(provided) + require.Len(t, result, len(tc.expected)) + + for _, v := range result { + require.Contains(t, tc.expected, v) + } + }) + } +} + +func TestMaps_Values_Int(t *testing.T) { + testCases := []valuesTestCase[int, int]{ + { + name: "empty", + builder: func() map[int]int { + return make(map[int]int, 0) + }, + expected: []int{}, + }, + { + name: "valid", + builder: func() map[int]int { + return map[int]int{ + 1: 2, + 3: 4, + 5: 6, + } + }, + expected: []int{2, 4, 6}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provided := tc.builder() + + result := Values(provided) + require.Len(t, result, len(tc.expected)) + + for _, v := range result { + require.Contains(t, tc.expected, v) + } + }) + } +} + +func TestMaps_Values_Struct(t *testing.T) { + type customStruct struct { + ID string + Name string + } + + testCases := []valuesTestCase[int, customStruct]{ + { + name: "empty", + builder: func() map[int]customStruct { + return make(map[int]customStruct, 0) + }, + expected: []customStruct{}, + }, + { + name: "valid", + builder: func() map[int]customStruct { + return map[int]customStruct{ + 1: {ID: "1", Name: "test"}, + } + }, + expected: []customStruct{{ID: "1", Name: "test"}}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provided := tc.builder() + + result := Values(provided) + require.Len(t, result, len(tc.expected)) + + for _, v := range result { + require.Contains(t, tc.expected, v) + } + }) + } +} diff --git a/utils.go b/pkg/utils/slices.go similarity index 91% rename from utils.go rename to pkg/utils/slices.go index b234f48..92b49ba 100644 --- a/utils.go +++ b/pkg/utils/slices.go @@ -1,4 +1,4 @@ -package proton +package utils import "slices" diff --git a/pkg/utils/slices_test.go b/pkg/utils/slices_test.go new file mode 100644 index 0000000..697ca78 --- /dev/null +++ b/pkg/utils/slices_test.go @@ -0,0 +1,165 @@ +package utils + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type filterTestCase[T any] struct { + name string + input []T + keep func(T) bool + expected []T +} + +func TestSlice_Filter_Int(t *testing.T) { + testCases := []filterTestCase[int]{ + { + name: "empty", + input: []int{}, + keep: func(_ int) bool { + return true + }, + expected: []int{}, + }, + { + name: "all", + input: []int{1, 2, 3}, + keep: func(_ int) bool { + return true + }, + expected: []int{1, 2, 3}, + }, + { + name: "none", + input: []int{1, 2, 3}, + keep: func(_ int) bool { + return false + }, + expected: []int{}, + }, + { + name: "only one", + input: []int{1, 2, 3}, + keep: func(i int) bool { + return i == 2 + }, + expected: []int{2}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := Filter(tc.input, tc.keep) + + require.Equal(t, tc.expected, result) + require.Len(t, result, len(tc.expected)) + }) + } +} + +func TestSlice_Filter_String(t *testing.T) { + testCases := []filterTestCase[string]{ + { + name: "empty", + input: []string{}, + keep: func(_ string) bool { + return true + }, + expected: []string{}, + }, + { + name: "all", + input: []string{"a", "b", "c"}, + keep: func(_ string) bool { + return true + }, + expected: []string{"a", "b", "c"}, + }, + { + name: "none", + input: []string{"a", "b", "c"}, + keep: func(_ string) bool { + return false + }, + expected: []string{}, + }, + { + name: "only one", + input: []string{"a", "b", "c"}, + keep: func(s string) bool { + return s == "b" + }, + expected: []string{"b"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := Filter(tc.input, tc.keep) + + require.Equal(t, tc.expected, result) + require.Len(t, result, len(tc.expected)) + }) + } +} + +func TestSlice_Filter_Struct(t *testing.T) { + type testStruct struct { + ID int + Name string + } + + newRandomTestStruct := func(id int) testStruct { + return testStruct{ + ID: id, + Name: fmt.Sprintf("test-%d", id), + } + } + + testCases := []filterTestCase[testStruct]{ + { + name: "empty", + input: []testStruct{}, + keep: func(_ testStruct) bool { + return true + }, + expected: []testStruct{}, + }, + { + name: "all", + input: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)}, + keep: func(_ testStruct) bool { + return true + }, + expected: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)}, + }, + { + name: "none", + input: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)}, + keep: func(_ testStruct) bool { + return false + }, + expected: []testStruct{}, + }, + { + name: "specific id", + input: []testStruct{newRandomTestStruct(1), newRandomTestStruct(2), newRandomTestStruct(3)}, + keep: func(ts testStruct) bool { + return ts.ID == 2 + }, + expected: []testStruct{newRandomTestStruct(2)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := Filter(tc.input, tc.keep) + + require.Equal(t, tc.expected, result) + require.Len(t, result, len(tc.expected)) + }) + } +} diff --git a/server/addresses.go b/server/addresses.go index c127ea5..fbd8a3b 100644 --- a/server/addresses.go +++ b/server/addresses.go @@ -2,10 +2,10 @@ package server import ( "net/http" + "slices" "github.com/ProtonMail/go-proton-api" "github.com/gin-gonic/gin" - "golang.org/x/exp/slices" ) func (s *Server) handleGetAddresses() gin.HandlerFunc { diff --git a/server/backend/api.go b/server/backend/api.go index 1e23db9..50ac621 100644 --- a/server/backend/api.go +++ b/server/backend/api.go @@ -12,9 +12,9 @@ import ( "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/pkg/utils" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" - "golang.org/x/exp/maps" ) func (b *Backend) GetUser(userID string) (proton.User, error) { @@ -183,7 +183,7 @@ func (b *Backend) GetAddress(userID, addrID string) (proton.Address, error) { func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) { return readBackendRetErr(b, func(b *unsafeBackend) ([]proton.Address, error) { return withAcc(b, userID, func(acc *account) ([]proton.Address, error) { - return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address { + return xslices.Map(utils.Values(acc.addresses), func(add *address) proton.Address { return add.toAddress() }), nil }) @@ -342,7 +342,7 @@ func (b *Backend) GetLabels(userID string, types ...proton.LabelType) ([]proton. } if len(types) > 0 { - res = proton.Filter(res, func(label proton.Label) bool { + res = utils.Filter(res, func(label proton.Label) bool { return slices.Contains(types, label.Type) }) } @@ -432,7 +432,7 @@ func (b *Backend) DeleteLabel(userID, labelID string) error { return err } - acc.labelIDs = proton.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID }) + acc.labelIDs = utils.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID }) acc.updateIDs = append(acc.updateIDs, updateID) } @@ -506,7 +506,7 @@ func (b *Backend) GetMessages(userID string, page, pageSize int, filter proton.M } } - metadata = proton.Filter(metadata, func(metadata proton.MessageMetadata) bool { + metadata = utils.Filter(metadata, func(metadata proton.MessageMetadata) bool { if len(filter.ID) > 0 { if !slices.Contains(filter.ID, metadata.ID) { return false @@ -693,7 +693,7 @@ func (b *Backend) DeleteMessage(userID, messageID string) error { } for _, attID := range message.attIDs { - if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool { + if xslices.CountFunc(utils.Values(b.attachments), func(att *attachment) bool { return att.attDataID == b.attachments[attID].attDataID }) == 1 { delete(b.attData, b.attachments[attID].attDataID) @@ -709,7 +709,7 @@ func (b *Backend) DeleteMessage(userID, messageID string) error { return err } - acc.messageIDs = proton.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID }) + acc.messageIDs = utils.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID }) acc.updateIDs = append(acc.updateIDs, updateID) return nil @@ -1235,7 +1235,7 @@ func (b *Backend) GetUserContacts(userID string, page int, pageSize int) (int, [ contacts, err := readBackendRetErr(b, func(b *unsafeBackend) ([]proton.Contact, error) { return withAcc(b, userID, func(acc *account) ([]proton.Contact, error) { total = len(acc.contacts) - values := maps.Values(acc.contacts) + values := utils.Values(acc.contacts) slices.SortFunc(values, func(i, j *proton.Contact) int { return strings.Compare(i.ID, j.ID) }) diff --git a/server/backend/backend.go b/server/backend/backend.go index 41883e2..dd3570f 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -3,18 +3,18 @@ package backend import ( "fmt" "net/mail" + "slices" "sync" "time" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/pkg/utils" "github.com/ProtonMail/go-srp" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) var log = logrus.WithField("pkg", "gpa/server/backend") @@ -168,7 +168,7 @@ func (b *Backend) RemoveUser(userID string) error { for _, messageID := range user.messageIDs { for _, attID := range b.messages[messageID].attIDs { - if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool { + if xslices.CountFunc(utils.Values(b.attachments), func(att *attachment) bool { return att.attDataID == b.attachments[attID].attDataID }) == 1 { delete(b.attData, b.attachments[attID].attDataID) diff --git a/server/backend/message.go b/server/backend/message.go index 5361737..63e7ebd 100644 --- a/server/backend/message.go +++ b/server/backend/message.go @@ -2,14 +2,15 @@ package backend import ( "net/mail" + "slices" "strings" "time" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/pkg/utils" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" - "golang.org/x/exp/slices" ) type message struct { @@ -320,7 +321,7 @@ func (msg *message) addLabel(labelID string, labels map[string]*label) { } func (msg *message) addFlagLabel(labelID string, labels map[string]*label) { - msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool { + msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool { return labels[otherLabelID].labelType == proton.LabelTypeLabel }) @@ -328,7 +329,7 @@ func (msg *message) addFlagLabel(labelID string, labels map[string]*label) { } func (msg *message) addSystemLabel(labelID string, labels map[string]*label) { - msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool { + msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool { return labels[otherLabelID].labelType == proton.LabelTypeLabel }) @@ -337,7 +338,7 @@ func (msg *message) addSystemLabel(labelID string, labels map[string]*label) { func (msg *message) addUserLabel(label *label, labels map[string]*label) { if label.labelType != proton.LabelTypeLabel { - msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool { + msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool { return labels[otherLabelID].labelType == proton.LabelTypeLabel }) @@ -380,7 +381,7 @@ func (msg *message) remSystemLabel(labelID string, labels map[string]*label) { } func (msg *message) remUserLabel(label *label, labels map[string]*label) { - msg.labelIDs = proton.Filter(msg.labelIDs, func(otherLabelID string) bool { + msg.labelIDs = utils.Filter(msg.labelIDs, func(otherLabelID string) bool { return otherLabelID != label.labelID }) } diff --git a/server/messages.go b/server/messages.go index c3c6f8d..1b85090 100644 --- a/server/messages.go +++ b/server/messages.go @@ -7,6 +7,7 @@ import ( "mime" "net/http" "net/mail" + "slices" "strconv" "time" @@ -15,7 +16,6 @@ import ( "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/gin-gonic/gin" - "golang.org/x/exp/slices" ) const ( @@ -489,14 +489,14 @@ func (s *Server) parseMessage(literal []byte) (*rfc822.Header, []string, []*rfc8 mimeType = "multipart/mixed" children, err := root.Children() // or determine it if there is only one (non-attachment) child - if err == nil && (len(children) - len(atts)) <= 1 { + if err == nil && (len(children)-len(atts)) <= 1 { var isHtml = false var isTxt = false for _, child := range children { contentType, _, err := child.ContentType() if err != nil { continue - }else if contentType == rfc822.TextHTML { + } else if contentType == rfc822.TextHTML { isHtml = true } else if contentType == rfc822.TextPlain { isTxt = true diff --git a/server/server_test.go b/server/server_test.go index 80dd010..07c1634 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -11,6 +11,7 @@ import ( "net/url" "os" "runtime" + "slices" "sync" "sync/atomic" "testing" @@ -20,6 +21,7 @@ import ( "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/pkg/utils" "github.com/ProtonMail/go-proton-api/server/backend" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/iterator" @@ -30,7 +32,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/slices" ) func TestServer_LoginLogout(t *testing.T) { @@ -2216,7 +2217,7 @@ func TestServer_GetMessageGroupCount(t *testing.T) { counts, err := c.GetGroupedMessageCount(ctx) require.NoError(t, err) - counts = proton.Filter(counts, func(t proton.MessageGroupCount) bool { + counts = utils.Filter(counts, func(t proton.MessageGroupCount) bool { switch t.LabelID { case proton.InboxLabel, proton.TrashLabel, proton.ArchiveLabel, proton.AllMailLabel, proton.SentLabel: return true