mirror of
https://github.com/ProtonMail/go-proton-api.git
synced 2025-12-23 23:57:50 -05:00
fix(GODT-3117): Improve Contact Info Retrieval
Rather than only fetching the total in on request and discarding all the data, re-use the first page of data and then collect more of them if the data set exceeds the page size. This patch also includes various fixes to the GPA server to mimic proton server behavior.
This commit is contained in:
committed by
LBeernaertProton
parent
8a47c8d92f
commit
65479b90c4
84
contact.go
84
contact.go
@@ -50,8 +50,15 @@ func (c *Client) CountContactEmails(ctx context.Context, email string) (int, err
|
||||
}
|
||||
|
||||
func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) {
|
||||
_, contacts, err := c.getContactsImpl(ctx, page, pageSize)
|
||||
|
||||
return contacts, err
|
||||
}
|
||||
|
||||
func (c *Client) getContactsImpl(ctx context.Context, page, pageSize int) (int, []Contact, error) {
|
||||
var res struct {
|
||||
Contacts []Contact
|
||||
Total int
|
||||
}
|
||||
|
||||
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
@@ -60,26 +67,58 @@ func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact
|
||||
"PageSize": strconv.Itoa(pageSize),
|
||||
}).SetResult(&res).Get("/contacts/v4")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return res.Contacts, nil
|
||||
return res.Total, res.Contacts, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) {
|
||||
total, err := c.CountContacts(ctx)
|
||||
return c.GetAllContactsPaged(ctx, maxPageSize)
|
||||
}
|
||||
|
||||
func (c *Client) GetAllContactsPaged(ctx context.Context, pageSize int) ([]Contact, error) {
|
||||
if pageSize > maxPageSize {
|
||||
pageSize = maxPageSize
|
||||
}
|
||||
|
||||
total, firstBatch, err := c.getContactsImpl(ctx, 0, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]Contact, error) {
|
||||
return c.GetContacts(ctx, page, pageSize)
|
||||
})
|
||||
if total <= pageSize {
|
||||
return firstBatch, nil
|
||||
}
|
||||
|
||||
remainingPages := (total / pageSize) + 1
|
||||
|
||||
for i := 1; i < remainingPages; i++ {
|
||||
_, batch, err := c.getContactsImpl(ctx, i, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
firstBatch = append(firstBatch, batch...)
|
||||
}
|
||||
|
||||
return firstBatch, err
|
||||
}
|
||||
|
||||
func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) {
|
||||
if pageSize > maxPageSize {
|
||||
pageSize = maxPageSize
|
||||
}
|
||||
|
||||
_, contacts, err := c.getContactEmailsImpl(ctx, email, page, pageSize)
|
||||
|
||||
return contacts, err
|
||||
}
|
||||
|
||||
func (c *Client) getContactEmailsImpl(ctx context.Context, email string, page, pageSize int) (int, []ContactEmail, error) {
|
||||
var res struct {
|
||||
ContactEmails []ContactEmail
|
||||
Total int
|
||||
}
|
||||
|
||||
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
|
||||
@@ -89,21 +128,42 @@ func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageS
|
||||
"Email": email,
|
||||
}).SetResult(&res).Get("/contacts/v4/emails")
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return res.ContactEmails, nil
|
||||
return res.Total, res.ContactEmails, nil
|
||||
}
|
||||
|
||||
func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) {
|
||||
total, err := c.CountContactEmails(ctx, email)
|
||||
return c.GetAllContactEmailsPaged(ctx, email, maxPageSize)
|
||||
}
|
||||
|
||||
func (c *Client) GetAllContactEmailsPaged(ctx context.Context, email string, pageSize int) ([]ContactEmail, error) {
|
||||
if pageSize > maxPageSize {
|
||||
pageSize = maxPageSize
|
||||
}
|
||||
|
||||
total, firstBatch, err := c.getContactEmailsImpl(ctx, email, 0, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) {
|
||||
return c.GetContactEmails(ctx, email, page, pageSize)
|
||||
})
|
||||
if total <= pageSize {
|
||||
return firstBatch, nil
|
||||
}
|
||||
|
||||
remainingPages := (total / pageSize) + 1
|
||||
|
||||
for i := 1; i < remainingPages; i++ {
|
||||
_, batch, err := c.getContactEmailsImpl(ctx, email, i, pageSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
firstBatch = append(firstBatch, batch...)
|
||||
}
|
||||
|
||||
return firstBatch, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) {
|
||||
|
||||
@@ -109,6 +109,17 @@ func (c *Card) Set(kr *crypto.KeyRing, key string, value *vcard.Field) error {
|
||||
return c.encode(kr, dec)
|
||||
}
|
||||
|
||||
func (c *Card) Add(kr *crypto.KeyRing, key string, value *vcard.Field) error {
|
||||
dec, err := c.decode(kr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dec.Add(key, value)
|
||||
|
||||
return c.encode(kr, dec)
|
||||
}
|
||||
|
||||
func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error {
|
||||
dec, err := c.decode(kr)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,12 +10,13 @@ import (
|
||||
)
|
||||
|
||||
type account struct {
|
||||
userID string
|
||||
username string
|
||||
addresses map[string]*address
|
||||
mailSettings *mailSettings
|
||||
userSettings proton.UserSettings
|
||||
contacts map[string]*proton.Contact
|
||||
userID string
|
||||
username string
|
||||
addresses map[string]*address
|
||||
mailSettings *mailSettings
|
||||
userSettings proton.UserSettings
|
||||
contacts map[string]*proton.Contact
|
||||
contactCounter int
|
||||
|
||||
auth map[string]auth
|
||||
authLock sync.RWMutex
|
||||
|
||||
@@ -1106,18 +1106,26 @@ func (b *Backend) GetUserContact(userID, contactID string) (proton.Contact, erro
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Backend) GetUserContacts(userID string) ([]proton.Contact, error) {
|
||||
return withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
|
||||
var contacts []proton.Contact
|
||||
for _, contact := range acc.contacts {
|
||||
contacts = append(contacts, *contact)
|
||||
}
|
||||
return contacts, nil
|
||||
func (b *Backend) GetUserContacts(userID string, page int, pageSize int) (int, []proton.Contact, error) {
|
||||
var total int
|
||||
contacts, err := withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
|
||||
total = len(acc.contacts)
|
||||
values := maps.Values(acc.contacts)
|
||||
slices.SortFunc(values, func(i, j *proton.Contact) bool {
|
||||
return strings.Compare(i.ID, j.ID) < 0
|
||||
})
|
||||
return xslices.Map(xslices.Chunk(values, pageSize)[page], func(c *proton.Contact) proton.Contact {
|
||||
return *c
|
||||
}), nil
|
||||
})
|
||||
|
||||
return total, contacts, err
|
||||
}
|
||||
|
||||
func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEmail, error) {
|
||||
return withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) {
|
||||
func (b *Backend) GetUserContactEmails(userID, email string, page int, pageSize int) (int, []proton.ContactEmail, error) {
|
||||
var total int
|
||||
|
||||
emails, err := withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) {
|
||||
var contacts []proton.ContactEmail
|
||||
for _, contact := range acc.contacts {
|
||||
for _, contactEmail := range contact.ContactEmails {
|
||||
@@ -1126,8 +1134,21 @@ func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEm
|
||||
}
|
||||
}
|
||||
}
|
||||
return contacts, nil
|
||||
|
||||
total = len(contacts)
|
||||
|
||||
if total < pageSize {
|
||||
return contacts, nil
|
||||
}
|
||||
|
||||
slices.SortFunc(contacts, func(a, b proton.ContactEmail) bool {
|
||||
return strings.Compare(a.ID, b.ID) < 0
|
||||
})
|
||||
|
||||
return xslices.Chunk(contacts, pageSize)[page], nil
|
||||
})
|
||||
|
||||
return total, emails, err
|
||||
}
|
||||
|
||||
func (b *Backend) AddUserContact(userID string, contact proton.Contact) (proton.Contact, error) {
|
||||
@@ -1146,15 +1167,8 @@ func (b *Backend) UpdateUserContact(userID, contactID string, cards proton.Cards
|
||||
|
||||
func (b *Backend) GenerateContactID(userID string) (string, error) {
|
||||
return withAcc(b, userID, func(acc *account) (string, error) {
|
||||
var lastKey = "0"
|
||||
for k := range acc.contacts {
|
||||
lastKey = k
|
||||
}
|
||||
newKey, err := strconv.Atoi(lastKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.Itoa(newKey + 1), nil
|
||||
acc.contactCounter++
|
||||
return strconv.Itoa(acc.contactCounter), nil
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -3,9 +3,14 @@ package backend
|
||||
import (
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-vcard"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var globalContactID int32
|
||||
|
||||
func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRing) (proton.Contact, error) {
|
||||
emails, err := card.Get(kr, vcard.FieldEmail)
|
||||
if err != nil {
|
||||
@@ -19,13 +24,15 @@ func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRin
|
||||
ContactMetadata: proton.ContactMetadata{
|
||||
ID: contactID,
|
||||
Name: names[0].Value,
|
||||
ContactEmails: []proton.ContactEmail{proton.ContactEmail{
|
||||
ID: "1",
|
||||
Name: names[0].Value,
|
||||
Email: emails[0].Value,
|
||||
ContactID: contactID,
|
||||
},
|
||||
},
|
||||
ContactEmails: xslices.Map(emails, func(email *vcard.Field) proton.ContactEmail {
|
||||
id := atomic.AddInt32(&globalContactID, 1)
|
||||
return proton.ContactEmail{
|
||||
ID: strconv.Itoa(int(id)),
|
||||
Name: names[0].Value,
|
||||
Email: email.Value,
|
||||
ContactID: contactID,
|
||||
}
|
||||
}),
|
||||
},
|
||||
ContactCards: proton.ContactCards{Cards: proton.Cards{card}},
|
||||
}, nil
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server/backend"
|
||||
@@ -10,22 +11,29 @@ import (
|
||||
|
||||
func (s *Server) handleGetContacts() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
contacts, err := s.b.GetUserContacts(c.GetString("UserID"))
|
||||
total, contacts, err := s.b.GetUserContacts(c.GetString("UserID"),
|
||||
mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))),
|
||||
mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))),
|
||||
)
|
||||
if err != nil {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"Code": proton.MultiCode,
|
||||
"ContactEmails": contacts,
|
||||
"Code": proton.MultiCode,
|
||||
"Contacts": contacts,
|
||||
"Total": total,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleGetContactsEmails() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.GetString("email"))
|
||||
total, contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.Query("Email"),
|
||||
mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))),
|
||||
mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))),
|
||||
)
|
||||
if err != nil {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
@@ -33,6 +41,7 @@ func (s *Server) handleGetContactsEmails() gin.HandlerFunc {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"Code": proton.MultiCode,
|
||||
"ContactEmails": contacts,
|
||||
"Total": total,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,9 @@ import (
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/stream"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-vcard"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
@@ -2314,6 +2316,219 @@ func TestServer_TestDraftActions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_Contacts(t *testing.T) {
|
||||
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
|
||||
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
|
||||
|
||||
user, err := c.GetUser(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr, err := c.GetAddresses(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
salt, err := c.GetSalts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
type testContact struct {
|
||||
Name string
|
||||
Email string
|
||||
}
|
||||
|
||||
testContacts := []testContact{
|
||||
{
|
||||
Name: "foo",
|
||||
Email: "foo@bar.com",
|
||||
},
|
||||
{
|
||||
Name: "bar",
|
||||
Email: "bar@bar.com",
|
||||
},
|
||||
{
|
||||
Name: "zz",
|
||||
Email: "zz@bar.com",
|
||||
},
|
||||
}
|
||||
|
||||
contactDesc := []proton.ContactCards{
|
||||
{
|
||||
Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card {
|
||||
return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Email)
|
||||
}),
|
||||
},
|
||||
}
|
||||
createReq := proton.CreateContactsReq{
|
||||
Contacts: contactDesc,
|
||||
Overwrite: 0,
|
||||
Labels: 0,
|
||||
}
|
||||
|
||||
contactsRes, err := c.CreateContacts(ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, len(contactsRes))
|
||||
|
||||
contacts, err := c.GetAllContactsPaged(ctx, 2)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, contacts, len(testContacts))
|
||||
|
||||
for _, v := range testContacts {
|
||||
require.NotEqual(t, -1, xslices.IndexFunc(contacts, func(contact proton.Contact) bool {
|
||||
return contact.Name == v.Name
|
||||
}))
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_ContactEmails(t *testing.T) {
|
||||
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
|
||||
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
|
||||
|
||||
user, err := c.GetUser(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr, err := c.GetAddresses(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
salt, err := c.GetSalts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
type testContact struct {
|
||||
Name string
|
||||
Emails []string
|
||||
}
|
||||
|
||||
testContacts := []testContact{
|
||||
{
|
||||
Name: "foo",
|
||||
Emails: []string{"foo@bar.com", "alias@alias.com", "nn@zz.com", "abc@4.de", "001234@00.com"},
|
||||
},
|
||||
{
|
||||
Name: "bar",
|
||||
Emails: []string{"bar@bar.com"},
|
||||
},
|
||||
{
|
||||
Name: "zz",
|
||||
Emails: []string{"zz@bar.com", "zz@zz2.com"},
|
||||
},
|
||||
}
|
||||
|
||||
contactDesc := []proton.ContactCards{
|
||||
{
|
||||
Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card {
|
||||
return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Emails...)
|
||||
}),
|
||||
},
|
||||
}
|
||||
createReq := proton.CreateContactsReq{
|
||||
Contacts: contactDesc,
|
||||
Overwrite: 0,
|
||||
Labels: 0,
|
||||
}
|
||||
|
||||
contactsRes, err := c.CreateContacts(ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, len(contactsRes))
|
||||
|
||||
for _, v := range testContacts {
|
||||
for _, email := range v.Emails {
|
||||
emails, err := c.GetAllContactEmailsPaged(ctx, email, 2)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, emails, 1)
|
||||
assert.Equal(t, email, emails[0].Email)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServer_ContactEmailsRepeated(t *testing.T) {
|
||||
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
|
||||
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
|
||||
|
||||
user, err := c.GetUser(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr, err := c.GetAddresses(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
salt, err := c.GetSalts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
type testContact struct {
|
||||
Name string
|
||||
Emails []string
|
||||
}
|
||||
|
||||
testContacts := []testContact{
|
||||
{
|
||||
Name: "foo",
|
||||
Emails: []string{"foo@bar.com"},
|
||||
},
|
||||
{
|
||||
Name: "bar",
|
||||
Emails: []string{"foo@bar.com"},
|
||||
},
|
||||
{
|
||||
Name: "zz",
|
||||
Emails: []string{"foo@bar.com"},
|
||||
},
|
||||
}
|
||||
|
||||
contactDesc := []proton.ContactCards{
|
||||
{
|
||||
Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card {
|
||||
return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Emails...)
|
||||
}),
|
||||
},
|
||||
}
|
||||
createReq := proton.CreateContactsReq{
|
||||
Contacts: contactDesc,
|
||||
Overwrite: 0,
|
||||
Labels: 0,
|
||||
}
|
||||
|
||||
contactsRes, err := c.CreateContacts(ctx, createReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, len(contactsRes))
|
||||
|
||||
emails, err := c.GetAllContactEmailsPaged(ctx, "foo@bar.com", 2)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, emails, len(testContacts))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func createVCard(t *testing.T, addrKR *crypto.KeyRing, name string, email ...string) *proton.Card {
|
||||
card, err := proton.NewCard(addrKR, proton.CardTypeSigned)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, card.Set(addrKR, vcard.FieldUID, &vcard.Field{Value: fmt.Sprintf("proton-legacy-%v", uuid.NewString()), Group: "test"}))
|
||||
require.NoError(t, card.Set(addrKR, vcard.FieldFormattedName, &vcard.Field{Value: name, Group: "test"}))
|
||||
for _, email := range email {
|
||||
require.NoError(t, card.Add(addrKR, vcard.FieldEmail, &vcard.Field{Value: email, Group: "test"}))
|
||||
}
|
||||
|
||||
return card
|
||||
}
|
||||
|
||||
func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
Reference in New Issue
Block a user