fix(GODT-3117): Improve Contact Info Retrieval

Rather than only fetching the total in on request and discarding all the
data, re-use the first page of data and then collect more of them if the
data set exceeds the page size.

This patch also includes various fixes to the GPA server to mimic
proton server behavior.
This commit is contained in:
Leander Beernaert
2023-11-20 15:23:24 +01:00
committed by LBeernaertProton
parent 8a47c8d92f
commit 65479b90c4
7 changed files with 365 additions and 48 deletions

View File

@@ -50,8 +50,15 @@ func (c *Client) CountContactEmails(ctx context.Context, email string) (int, err
}
func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) {
_, contacts, err := c.getContactsImpl(ctx, page, pageSize)
return contacts, err
}
func (c *Client) getContactsImpl(ctx context.Context, page, pageSize int) (int, []Contact, error) {
var res struct {
Contacts []Contact
Total int
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
@@ -60,26 +67,58 @@ func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact
"PageSize": strconv.Itoa(pageSize),
}).SetResult(&res).Get("/contacts/v4")
}); err != nil {
return nil, err
return 0, nil, err
}
return res.Contacts, nil
return res.Total, res.Contacts, nil
}
func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) {
total, err := c.CountContacts(ctx)
return c.GetAllContactsPaged(ctx, maxPageSize)
}
func (c *Client) GetAllContactsPaged(ctx context.Context, pageSize int) ([]Contact, error) {
if pageSize > maxPageSize {
pageSize = maxPageSize
}
total, firstBatch, err := c.getContactsImpl(ctx, 0, pageSize)
if err != nil {
return nil, err
}
return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]Contact, error) {
return c.GetContacts(ctx, page, pageSize)
})
if total <= pageSize {
return firstBatch, nil
}
remainingPages := (total / pageSize) + 1
for i := 1; i < remainingPages; i++ {
_, batch, err := c.getContactsImpl(ctx, i, pageSize)
if err != nil {
return nil, err
}
firstBatch = append(firstBatch, batch...)
}
return firstBatch, err
}
func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) {
if pageSize > maxPageSize {
pageSize = maxPageSize
}
_, contacts, err := c.getContactEmailsImpl(ctx, email, page, pageSize)
return contacts, err
}
func (c *Client) getContactEmailsImpl(ctx context.Context, email string, page, pageSize int) (int, []ContactEmail, error) {
var res struct {
ContactEmails []ContactEmail
Total int
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
@@ -89,21 +128,42 @@ func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageS
"Email": email,
}).SetResult(&res).Get("/contacts/v4/emails")
}); err != nil {
return nil, err
return 0, nil, err
}
return res.ContactEmails, nil
return res.Total, res.ContactEmails, nil
}
func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) {
total, err := c.CountContactEmails(ctx, email)
return c.GetAllContactEmailsPaged(ctx, email, maxPageSize)
}
func (c *Client) GetAllContactEmailsPaged(ctx context.Context, email string, pageSize int) ([]ContactEmail, error) {
if pageSize > maxPageSize {
pageSize = maxPageSize
}
total, firstBatch, err := c.getContactEmailsImpl(ctx, email, 0, pageSize)
if err != nil {
return nil, err
}
return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) {
return c.GetContactEmails(ctx, email, page, pageSize)
})
if total <= pageSize {
return firstBatch, nil
}
remainingPages := (total / pageSize) + 1
for i := 1; i < remainingPages; i++ {
_, batch, err := c.getContactEmailsImpl(ctx, email, i, pageSize)
if err != nil {
return nil, err
}
firstBatch = append(firstBatch, batch...)
}
return firstBatch, err
}
func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) {

View File

@@ -109,6 +109,17 @@ func (c *Card) Set(kr *crypto.KeyRing, key string, value *vcard.Field) error {
return c.encode(kr, dec)
}
func (c *Card) Add(kr *crypto.KeyRing, key string, value *vcard.Field) error {
dec, err := c.decode(kr)
if err != nil {
return err
}
dec.Add(key, value)
return c.encode(kr, dec)
}
func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error {
dec, err := c.decode(kr)
if err != nil {

View File

@@ -10,12 +10,13 @@ import (
)
type account struct {
userID string
username string
addresses map[string]*address
mailSettings *mailSettings
userSettings proton.UserSettings
contacts map[string]*proton.Contact
userID string
username string
addresses map[string]*address
mailSettings *mailSettings
userSettings proton.UserSettings
contacts map[string]*proton.Contact
contactCounter int
auth map[string]auth
authLock sync.RWMutex

View File

@@ -1106,18 +1106,26 @@ func (b *Backend) GetUserContact(userID, contactID string) (proton.Contact, erro
})
}
func (b *Backend) GetUserContacts(userID string) ([]proton.Contact, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
var contacts []proton.Contact
for _, contact := range acc.contacts {
contacts = append(contacts, *contact)
}
return contacts, nil
func (b *Backend) GetUserContacts(userID string, page int, pageSize int) (int, []proton.Contact, error) {
var total int
contacts, err := withAcc(b, userID, func(acc *account) ([]proton.Contact, error) {
total = len(acc.contacts)
values := maps.Values(acc.contacts)
slices.SortFunc(values, func(i, j *proton.Contact) bool {
return strings.Compare(i.ID, j.ID) < 0
})
return xslices.Map(xslices.Chunk(values, pageSize)[page], func(c *proton.Contact) proton.Contact {
return *c
}), nil
})
return total, contacts, err
}
func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEmail, error) {
return withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) {
func (b *Backend) GetUserContactEmails(userID, email string, page int, pageSize int) (int, []proton.ContactEmail, error) {
var total int
emails, err := withAcc(b, userID, func(acc *account) ([]proton.ContactEmail, error) {
var contacts []proton.ContactEmail
for _, contact := range acc.contacts {
for _, contactEmail := range contact.ContactEmails {
@@ -1126,8 +1134,21 @@ func (b *Backend) GetUserContactEmails(userID, email string) ([]proton.ContactEm
}
}
}
return contacts, nil
total = len(contacts)
if total < pageSize {
return contacts, nil
}
slices.SortFunc(contacts, func(a, b proton.ContactEmail) bool {
return strings.Compare(a.ID, b.ID) < 0
})
return xslices.Chunk(contacts, pageSize)[page], nil
})
return total, emails, err
}
func (b *Backend) AddUserContact(userID string, contact proton.Contact) (proton.Contact, error) {
@@ -1146,15 +1167,8 @@ func (b *Backend) UpdateUserContact(userID, contactID string, cards proton.Cards
func (b *Backend) GenerateContactID(userID string) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
var lastKey = "0"
for k := range acc.contacts {
lastKey = k
}
newKey, err := strconv.Atoi(lastKey)
if err != nil {
return "", err
}
return strconv.Itoa(newKey + 1), nil
acc.contactCounter++
return strconv.Itoa(acc.contactCounter), nil
})
}

View File

@@ -3,9 +3,14 @@ package backend
import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-vcard"
"strconv"
"sync/atomic"
)
var globalContactID int32
func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRing) (proton.Contact, error) {
emails, err := card.Get(kr, vcard.FieldEmail)
if err != nil {
@@ -19,13 +24,15 @@ func ContactCardToContact(card *proton.Card, contactID string, kr *crypto.KeyRin
ContactMetadata: proton.ContactMetadata{
ID: contactID,
Name: names[0].Value,
ContactEmails: []proton.ContactEmail{proton.ContactEmail{
ID: "1",
Name: names[0].Value,
Email: emails[0].Value,
ContactID: contactID,
},
},
ContactEmails: xslices.Map(emails, func(email *vcard.Field) proton.ContactEmail {
id := atomic.AddInt32(&globalContactID, 1)
return proton.ContactEmail{
ID: strconv.Itoa(int(id)),
Name: names[0].Value,
Email: email.Value,
ContactID: contactID,
}
}),
},
ContactCards: proton.ContactCards{Cards: proton.Cards{card}},
}, nil

View File

@@ -2,6 +2,7 @@ package server
import (
"net/http"
"strconv"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server/backend"
@@ -10,22 +11,29 @@ import (
func (s *Server) handleGetContacts() gin.HandlerFunc {
return func(c *gin.Context) {
contacts, err := s.b.GetUserContacts(c.GetString("UserID"))
total, contacts, err := s.b.GetUserContacts(c.GetString("UserID"),
mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))),
mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))),
)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
c.JSON(http.StatusOK, gin.H{
"Code": proton.MultiCode,
"ContactEmails": contacts,
"Code": proton.MultiCode,
"Contacts": contacts,
"Total": total,
})
}
}
func (s *Server) handleGetContactsEmails() gin.HandlerFunc {
return func(c *gin.Context) {
contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.GetString("email"))
total, contacts, err := s.b.GetUserContactEmails(c.GetString("UserID"), c.Query("Email"),
mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))),
mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))),
)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
@@ -33,6 +41,7 @@ func (s *Server) handleGetContactsEmails() gin.HandlerFunc {
c.JSON(http.StatusOK, gin.H{
"Code": proton.MultiCode,
"ContactEmails": contacts,
"Total": total,
})
}
}

View File

@@ -26,7 +26,9 @@ import (
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-vcard"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
@@ -2314,6 +2316,219 @@ func TestServer_TestDraftActions(t *testing.T) {
})
}
func TestServer_Contacts(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx)
require.NoError(t, err)
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
salt, err := c.GetSalts(ctx)
require.NoError(t, err)
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
require.NoError(t, err)
type testContact struct {
Name string
Email string
}
testContacts := []testContact{
{
Name: "foo",
Email: "foo@bar.com",
},
{
Name: "bar",
Email: "bar@bar.com",
},
{
Name: "zz",
Email: "zz@bar.com",
},
}
contactDesc := []proton.ContactCards{
{
Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card {
return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Email)
}),
},
}
createReq := proton.CreateContactsReq{
Contacts: contactDesc,
Overwrite: 0,
Labels: 0,
}
contactsRes, err := c.CreateContacts(ctx, createReq)
require.NoError(t, err)
assert.Equal(t, 3, len(contactsRes))
contacts, err := c.GetAllContactsPaged(ctx, 2)
require.NoError(t, err)
require.Len(t, contacts, len(testContacts))
for _, v := range testContacts {
require.NotEqual(t, -1, xslices.IndexFunc(contacts, func(contact proton.Contact) bool {
return contact.Name == v.Name
}))
}
})
})
}
func TestServer_ContactEmails(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx)
require.NoError(t, err)
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
salt, err := c.GetSalts(ctx)
require.NoError(t, err)
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
require.NoError(t, err)
type testContact struct {
Name string
Emails []string
}
testContacts := []testContact{
{
Name: "foo",
Emails: []string{"foo@bar.com", "alias@alias.com", "nn@zz.com", "abc@4.de", "001234@00.com"},
},
{
Name: "bar",
Emails: []string{"bar@bar.com"},
},
{
Name: "zz",
Emails: []string{"zz@bar.com", "zz@zz2.com"},
},
}
contactDesc := []proton.ContactCards{
{
Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card {
return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Emails...)
}),
},
}
createReq := proton.CreateContactsReq{
Contacts: contactDesc,
Overwrite: 0,
Labels: 0,
}
contactsRes, err := c.CreateContacts(ctx, createReq)
require.NoError(t, err)
assert.Equal(t, 3, len(contactsRes))
for _, v := range testContacts {
for _, email := range v.Emails {
emails, err := c.GetAllContactEmailsPaged(ctx, email, 2)
require.NoError(t, err)
require.Len(t, emails, 1)
assert.Equal(t, email, emails[0].Email)
}
}
})
})
}
func TestServer_ContactEmailsRepeated(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx)
require.NoError(t, err)
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
salt, err := c.GetSalts(ctx)
require.NoError(t, err)
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
require.NoError(t, err)
type testContact struct {
Name string
Emails []string
}
testContacts := []testContact{
{
Name: "foo",
Emails: []string{"foo@bar.com"},
},
{
Name: "bar",
Emails: []string{"foo@bar.com"},
},
{
Name: "zz",
Emails: []string{"foo@bar.com"},
},
}
contactDesc := []proton.ContactCards{
{
Cards: xslices.Map(testContacts, func(contact testContact) *proton.Card {
return createVCard(t, addrKRs[addr[0].ID], contact.Name, contact.Emails...)
}),
},
}
createReq := proton.CreateContactsReq{
Contacts: contactDesc,
Overwrite: 0,
Labels: 0,
}
contactsRes, err := c.CreateContacts(ctx, createReq)
require.NoError(t, err)
assert.Equal(t, 3, len(contactsRes))
emails, err := c.GetAllContactEmailsPaged(ctx, "foo@bar.com", 2)
require.NoError(t, err)
require.Len(t, emails, len(testContacts))
})
})
}
func createVCard(t *testing.T, addrKR *crypto.KeyRing, name string, email ...string) *proton.Card {
card, err := proton.NewCard(addrKR, proton.CardTypeSigned)
require.NoError(t, err)
require.NoError(t, card.Set(addrKR, vcard.FieldUID, &vcard.Field{Value: fmt.Sprintf("proton-legacy-%v", uuid.NewString()), Group: "test"}))
require.NoError(t, card.Set(addrKR, vcard.FieldFormattedName, &vcard.Field{Value: name, Group: "test"}))
for _, email := range email {
require.NoError(t, card.Add(addrKR, vcard.FieldEmail, &vcard.Field{Value: email, Group: "test"}))
}
return card
}
func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()