Files
go-proton-api/server/server_test.go
2022-12-13 13:12:36 +01:00

1808 lines
54 KiB
Go

package server
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"net/mail"
"net/url"
"os"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Masterminds/semver/v3"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
// TestServer_LoginLogout logs a user in, checks the account details the API
// reports for it, and then verifies that deleting the session invalidates it.
func TestServer_LoginLogout(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			got, err := c.GetUser(ctx)
			require.NoError(t, err)
			require.Equal(t, "user", got.Name)
			require.Equal(t, "user@"+s.GetDomain(), got.Email)

			// Ending the session succeeds once...
			require.NoError(t, c.AuthDelete(ctx))

			// ...and any later request on the dead session is rejected.
			require.Error(t, c.AuthDelete(ctx))
		})
	})
}
// TestServerMulti verifies that two sessions of the same user are independent:
// logging one of them out does not invalidate the other.
func TestServerMulti(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		_, _, err := s.CreateUser("user", []byte("pass"))
		require.NoError(t, err)

		// login opens a fresh session for the test user.
		login := func() *proton.Client {
			client, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
			require.NoError(t, err)
			return client
		}

		first := login()
		defer first.Close()

		second := login()
		defer second.Close()

		// Both sessions start out usable (must panics on error).
		must(first.GetUser(ctx))
		must(second.GetUser(ctx))

		// Revoke only the first session; its requests must now fail.
		require.NoError(t, first.AuthDelete(ctx))
		require.Panics(t, func() { must(first.GetUser(ctx)) })

		// The second session is unaffected.
		must(second.GetUser(ctx))
	})
}
// TestServer_Ping checks how Manager.Ping and the status observers react when
// the network is toggled (via NetCtl) and when the API is put offline.
// The assertions below show that Ping only fails when the server cannot be
// reached at all; an offline API that still answers is reported as up.
func TestServer_Ping(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, _ *proton.Manager) {
		ctl := proton.NewNetCtl()

		m := proton.New(
			proton.WithHostURL(s.GetHostURL()),
			proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
		)
		// Release the manager when the test ends, like the other tests that
		// build their own manager do.
		defer m.Close()

		var status proton.Status

		m.AddStatusObserver(func(st proton.Status) {
			status = st
		})

		// When the network goes down, ping should fail.
		ctl.Disable()
		require.Error(t, m.Ping(ctx))
		require.Equal(t, proton.StatusDown, status)

		// When the network goes up, ping should succeed.
		ctl.Enable()
		require.NoError(t, m.Ping(ctx))
		require.Equal(t, proton.StatusUp, status)

		// When the API is down, ping should still succeed if the API is at least reachable.
		s.SetOffline(true)
		require.NoError(t, m.Ping(ctx))
		require.Equal(t, proton.StatusUp, status)

		// When the API is down, ping should fail if the API cannot be reached.
		ctl.Disable()
		require.Error(t, m.Ping(ctx))
		require.Equal(t, proton.StatusDown, status)

		// When the network goes up, ping should succeed, even if the API is down.
		ctl.Enable()
		require.NoError(t, m.Ping(ctx))
		require.Equal(t, proton.StatusUp, status)

		// When the API comes back alive, ping should succeed.
		s.SetOffline(false)
		require.NoError(t, m.Ping(ctx))
		require.Equal(t, proton.StatusUp, status)
	})
}
// TestServer_Bool checks that the Unread flag round-trips through the API:
// a new message reads back as unread, and as read once it has been marked so.
func TestServer_Bool(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1, func([]string) {
				metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
				require.NoError(t, err)

				id := metadata[0].ID

				// isUnread refetches the message so each check sees server state.
				isUnread := func() bool {
					return bool(must(c.GetMessage(ctx, id)).Unread)
				}

				require.True(t, isUnread())

				require.NoError(t, c.MarkMessagesRead(ctx, id))

				require.False(t, isUnread())
			})
		})
	})
}
// TestServer_Messages creates 1000 messages and checks their initial state
// (unread, All Mail only). It then marks three of them read and moves them to
// Archive and back to Inbox, checking labels after each move.
func TestServer_Messages(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
				// Get the messages.
				metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
				require.NoError(t, err)

				// The messages should be the ones we created.
				require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string {
					return metadata.ID
				}))

				// The messages should be in All Mail and should be unread.
				for _, message := range metadata {
					require.True(t, bool(message.Unread))
					require.Equal(t, []string{proton.AllMailLabel}, message.LabelIDs)
				}

				firstThree := []string{messageIDs[0], messageIDs[1], messageIDs[2]}

				// Mark the first three messages as read and put them in archive.
				require.NoError(t, c.MarkMessagesRead(ctx, firstThree...))
				require.NoError(t, c.LabelMessages(ctx, firstThree, proton.ArchiveLabel))

				// They should now be read and in archive.
				for _, id := range firstThree {
					message := must(c.GetMessage(ctx, id))
					require.False(t, bool(message.Unread))
					require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, message.LabelIDs)
				}

				// Put them in inbox (this step must use InboxLabel; relabelling
				// with ArchiveLabel would be a no-op).
				require.NoError(t, c.LabelMessages(ctx, firstThree, proton.InboxLabel))

				// Moving into Inbox is expected to take them out of Archive, as
				// TestServer_Labels shows for the reverse move.
				for _, id := range firstThree {
					require.ElementsMatch(t, []string{proton.InboxLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, id)).LabelIDs)
				}
			})
		})
	})
}
// TestServer_MessageFilter checks that filtering message metadata by ID
// returns exactly the requested entries, in the same order as an unfiltered
// listing.
func TestServer_MessageFilter(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
				idOf := func(md proton.MessageMetadata) string { return md.ID }

				all, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
				require.NoError(t, err)
				require.ElementsMatch(t, messageIDs, xslices.Map(all, idOf))

				// Ask for only the first three entries of the full listing.
				head := all[:3]

				partial, err := c.GetMessageMetadata(ctx, proton.MessageFilter{
					ID: xslices.Map(head, idOf),
				})
				require.NoError(t, err)
				require.Equal(t, head, partial)
			})
		})
	})
}
// TestServer_MessageIDs checks that GetMessageIDs lists every message when
// given an empty start ID, and only the messages after the given ID otherwise.
func TestServer_MessageIDs(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 10000, func(wantMessageIDs []string) {
				all, err := c.GetMessageIDs(ctx, "")
				require.NoError(t, err)
				require.ElementsMatch(t, wantMessageIDs, all)

				mid := len(all) / 2

				// Starting from the middle ID excludes that ID itself.
				after, err := c.GetMessageIDs(ctx, all[mid])
				require.NoError(t, err)
				require.ElementsMatch(t, all[mid+1:], after)
			})
		})
	})
}
// TestServer_MessagesDelete deletes the first half of 1000 messages and checks
// that exactly the other half remains.
func TestServer_MessagesDelete(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
				ids := func(list []proton.MessageMetadata) []string {
					return xslices.Map(list, func(md proton.MessageMetadata) string { return md.ID })
				}

				before, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
				require.NoError(t, err)
				require.ElementsMatch(t, messageIDs, ids(before))

				deleted, kept := messageIDs[0:500], messageIDs[500:]

				require.NoError(t, c.DeleteMessage(ctx, deleted...))

				after, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
				require.NoError(t, err)
				require.ElementsMatch(t, kept, ids(after))
			})
		})
	})
}
// TestServer_MessagesDeleteAfterUpdate relabels and then deletes the same 500
// messages, and checks that the resulting event reports one delete per message
// rather than also reporting the earlier relabel.
func TestServer_MessagesDeleteAfterUpdate(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
				since, err := c.GetLatestEventID(ctx)
				require.NoError(t, err)

				target := messageIDs[0:500]

				require.NoError(t, c.LabelMessages(ctx, target, proton.ArchiveLabel))
				require.NoError(t, c.DeleteMessage(ctx, target...))

				event, err := c.GetEvent(ctx, since)
				require.NoError(t, err)
				require.Len(t, event.Messages, 500)

				for i := range event.Messages {
					require.Equal(t, proton.EventDelete, event.Messages[i].Action)
				}
			})
		})
	})
}
// TestServer_Events streams events for a user and checks that read-state and
// label changes to messages show up with the expected final state. It also
// checks that no further event is immediately pending afterwards.
func TestServer_Events(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 3, func(messageIDs []string) {
				// Stream from the current head of the event log.
				fromEventID, err := c.GetLatestEventID(ctx)
				require.NoError(t, err)

				eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID)

				const (
					waitFor = 5 * time.Second
					tick    = 100 * time.Millisecond
				)

				// Marking the first message read yields a single-message event
				// in which it is no longer unread.
				require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[0]))
				require.Eventually(t, func() bool {
					event := <-eventCh
					return len(event.Messages) == 1 &&
						event.Messages[0].ID == messageIDs[0] &&
						!bool(event.Messages[0].Message.Unread)
				}, waitFor, tick)

				// Archiving the second message yields a single-message event in
				// which it carries the Archive and All Mail labels.
				require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[1]}, proton.ArchiveLabel))
				require.Eventually(t, func() bool {
					event := <-eventCh
					return len(event.Messages) == 1 &&
						event.Messages[0].ID == messageIDs[1] &&
						elementsMatch([]string{proton.ArchiveLabel, proton.AllMailLabel}, event.Messages[0].Message.LabelIDs)
				}, waitFor, tick)

				// A burst of changes to the third message nets out to it being
				// unread and only in All Mail.
				require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[2]}, proton.TrashLabel))
				require.NoError(t, c.UnlabelMessages(ctx, []string{messageIDs[2]}, proton.TrashLabel))
				require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[2]))
				require.NoError(t, c.MarkMessagesUnread(ctx, messageIDs[2]))
				require.Eventually(t, func() bool {
					event := <-eventCh
					return len(event.Messages) == 1 &&
						event.Messages[0].ID == messageIDs[2] &&
						bool(event.Messages[0].Message.Unread) &&
						elementsMatch([]string{proton.AllMailLabel}, event.Messages[0].Message.LabelIDs)
				}, waitFor, tick)

				// Nothing else should be waiting on the stream right now.
				select {
				case <-eventCh:
					t.Fatal("unexpected event")
				default:
				}
			})
		})
	})
}
// TestServer_Events_Multi checks, for ten separate users, that the latest
// event ID is stable and that a fresh user's latest event is empty and does
// not advance when polled.
func TestServer_Events_Multi(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		for i := 0; i < 10; i++ {
			username := fmt.Sprintf("user%v", i)

			withUser(ctx, t, s, m, username, "pass", func(c *proton.Client) {
				first, err := c.GetLatestEventID(ctx)
				require.NoError(t, err)

				// Asking again without any intervening change yields the same ID.
				second, err := c.GetLatestEventID(ctx)
				require.NoError(t, err)
				require.Equal(t, first, second)

				// The event carries nothing but its own ID.
				event, err := c.GetEvent(ctx, first)
				require.NoError(t, err)
				require.Equal(t, proton.Event{EventID: event.EventID}, event)

				// Polling from that empty event does not move the event ID.
				again, err := c.GetEvent(ctx, event.EventID)
				require.NoError(t, err)
				require.Equal(t, again.EventID, event.EventID)
			})
		}
	})
}
// TestServer_Events_Refresh triggers a server-side mail refresh for a user and
// checks that the event stream eventually reports the RefreshMail bit.
func TestServer_Events_Refresh(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			// Record the head of the event log before the refresh.
			fromEventID, err := c.GetLatestEventID(ctx)
			require.NoError(t, err)

			require.NoError(t, s.RefreshUser(user.ID, proton.RefreshMail))

			events := c.NewEventStream(ctx, time.Second, 0, fromEventID)

			require.Eventually(t, func() bool {
				next := <-events
				return next.Refresh&proton.RefreshMail != 0
			}, 5*time.Second, 100*time.Millisecond)
		})
	})
}
// TestServer_RevokeUser checks that revoking a user's auth on the server side
// invalidates that user's existing client session.
func TestServer_RevokeUser(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			me, err := c.GetUser(ctx)
			require.NoError(t, err)
			require.Equal(t, "user", me.Name)
			require.Equal(t, "user@"+s.GetDomain(), me.Email)

			require.NoError(t, s.RevokeUser(me.ID))

			// The session is gone, so even logging out is refused.
			require.Error(t, c.AuthDelete(ctx))
		})
	})
}
// TestServer_Calls checks that call watchers record, in order, the path,
// method and status of each request a client makes.
func TestServer_Calls(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			var recorded []Call

			s.AddCallWatcher(func(call Call) {
				recorded = append(recorded, call)
			})

			_, err := c.GetUser(ctx)
			require.NoError(t, err)

			require.NoError(t, c.AuthDelete(ctx))

			// One entry per request above, in the order they were made.
			expected := []struct {
				path, method string
			}{
				{"/core/v4/users", "GET"},
				{"/core/v4/auth", "DELETE"},
			}

			for i, want := range expected {
				call := recorded[i]
				require.Equal(t, want.path, call.URL.Path)
				require.Equal(t, want.method, call.Method)
				require.Equal(t, http.StatusOK, call.Status)
			}
		})
	})
}
// TestServer_Calls_Status checks that a failing request is recorded by call
// watchers with the error status the server returned.
func TestServer_Calls_Status(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			var seen []Call

			s.AddCallWatcher(func(call Call) {
				seen = append(seen, call)
			})

			// Requesting a message that does not exist must fail...
			_, err := c.GetMessage(ctx, "no such message ID")
			require.Error(t, err)

			// ...and the recorded call carries the 422 status.
			require.Equal(t, http.StatusUnprocessableEntity, seen[0].Status)
		})
	})
}
// TestServer_Calls_Request checks that call watchers see the raw request body.
// The first call recorded while withUser runs should carry an AuthInfoReq for
// the user.
func TestServer_Calls_Request(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		var calls []Call

		s.AddCallWatcher(func(call Call) {
			calls = append(calls, call)
		})

		withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) {
			// testify's Equal takes (expected, actual); keep that order so a
			// failure labels the two sides correctly.
			require.Equal(
				t,
				must(json.Marshal(proton.AuthInfoReq{Username: "user"})),
				calls[0].RequestBody,
			)
		})
	})
}
// TestServer_Calls_Response checks that call watchers see the raw response
// body. The last recorded call (GetSalts) should contain the JSON-encoded salts
// the client decoded.
func TestServer_Calls_Response(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		var calls []Call

		s.AddCallWatcher(func(call Call) {
			calls = append(calls, call)
		})

		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			salts, err := c.GetSalts(ctx)
			require.NoError(t, err)

			// testify's Equal takes (expected, actual); keep that order so a
			// failure labels the two sides correctly.
			require.Equal(
				t,
				must(json.Marshal(struct{ KeySalts []proton.Salt }{salts})),
				calls[len(calls)-1].ResponseBody,
			)
		})
	})
}
// TestServer_Calls_Cookies checks the Session-Id cookie handshake. The
// second-to-last recorded response sets it, and the last recorded request
// sends the same value back.
func TestServer_Calls_Cookies(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		var calls []Call

		s.AddCallWatcher(func(call Call) {
			calls = append(calls, call)
		})

		withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) {
			n := len(calls)

			setCookies := (&http.Response{Header: calls[n-2].ResponseHeader}).Cookies()
			require.Len(t, setCookies, 1)
			require.Equal(t, "Session-Id", setCookies[0].Name)
			require.NotEmpty(t, setCookies[0].Value)

			sentCookies := (&http.Request{Header: calls[n-1].RequestHeader}).Cookies()
			require.Len(t, sentCookies, 1)
			require.Equal(t, "Session-Id", sentCookies[0].Name)
			require.NotEmpty(t, sentCookies[0].Value)

			// What the server set is exactly what the client returned.
			require.Equal(t, setCookies[0].Value, sentCookies[0].Value)
		})
	})
}
// TestServer_Calls_Manager checks that requests made through the manager
// itself (no user session) are also seen by call watchers.
func TestServer_Calls_Manager(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		var recorded []Call

		s.AddCallWatcher(func(call Call) {
			recorded = append(recorded, call)
		})

		// Bug reports do not require a logged-in client.
		require.NoError(t, m.ReportBug(ctx, proton.ReportBugReq{}))

		got := recorded[0]
		require.Equal(t, "/core/v4/reports/bug", got.URL.Path)
		require.Equal(t, "POST", got.Method)
		require.Equal(t, http.StatusOK, got.Status)
	})
}
// TestServer_CreateMessage creates a draft from the user's first address and
// checks the fields and labels of the draft the server returns.
func TestServer_CreateMessage(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			addrs, err := c.GetAddresses(ctx)
			require.NoError(t, err)

			salts, err := c.GetSalts(ctx)
			require.NoError(t, err)

			// Salt the password for the primary key and unlock the address keyrings.
			keyPass, err := salts.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
			require.NoError(t, err)

			_, addrKRs, err := proton.Unlock(user, addrs, keyPass)
			require.NoError(t, err)

			primary := addrs[0]

			draft, err := c.CreateDraft(ctx, addrKRs[primary.ID], proton.CreateDraftReq{
				Message: proton.DraftTemplate{
					Subject: "My subject",
					Sender:  &mail.Address{Address: primary.Email},
					ToList:  []*mail.Address{{Address: "recipient@example.com"}},
				},
			})
			require.NoError(t, err)

			require.Equal(t, primary.ID, draft.AddressID)
			require.Equal(t, "My subject", draft.Subject)
			require.Equal(t, &mail.Address{Address: "user@" + s.GetDomain()}, draft.Sender)
			require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, draft.ToList)
			require.ElementsMatch(t, []string{proton.AllMailLabel, proton.AllDraftsLabel, proton.DraftsLabel}, draft.LabelIDs)
		})
	})
}
// TestServer_UpdateDraft creates a draft, edits its subject and recipients,
// and checks both the UpdateDraft response and the resulting update event.
func TestServer_UpdateDraft(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			addr, err := c.GetAddresses(ctx)
			require.NoError(t, err)

			salt, err := c.GetSalts(ctx)
			require.NoError(t, err)

			pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
			require.NoError(t, err)

			_, addrKRs, err := proton.Unlock(user, addr, pass)
			require.NoError(t, err)

			// Create the draft.
			draft, err := c.CreateDraft(ctx, addrKRs[addr[0].ID], proton.CreateDraftReq{
				Message: proton.DraftTemplate{
					Subject: "My subject",
					Sender:  &mail.Address{Address: addr[0].Email},
					ToList:  []*mail.Address{{Address: "recipient@example.com"}},
				},
			})
			require.NoError(t, err)
			require.Equal(t, addr[0].ID, draft.AddressID)
			require.Equal(t, "My subject", draft.Subject)
			require.Equal(t, &mail.Address{Address: "user@" + s.GetDomain()}, draft.Sender)
			require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, draft.ToList)

			// Create an event stream to watch for an update event.
			fromEventID, err := c.GetLatestEventID(ctx)
			require.NoError(t, err)

			eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID)

			// Update the draft subject/to-list.
			msg, err := c.UpdateDraft(ctx, draft.ID, addrKRs[addr[0].ID], proton.UpdateDraftReq{
				Message: proton.DraftTemplate{
					Subject: "Edited subject",
					ToList:  []*mail.Address{{Address: "edited@example.com"}},
				},
			})
			require.NoError(t, err)
			require.Equal(t, "Edited subject", msg.Subject)

			// Wait for the update event for this draft. require.Eventually may
			// run the condition on a separate goroutine, and require.* uses
			// t.FailNow, which must only be called from the test goroutine. So
			// the condition only records what it saw, and the assertions run
			// below once Eventually has returned.
			var (
				gotSubject string
				gotToList  []*mail.Address
			)

			require.Eventually(t, func() bool {
				event := <-eventCh
				if len(event.Messages) < 1 {
					return false
				}

				update := event.Messages[0]
				if update.ID != draft.ID || update.Action != proton.EventUpdate {
					return false
				}

				gotSubject = update.Message.Subject
				gotToList = update.Message.ToList

				return true
			}, 5*time.Second, time.Millisecond*100)

			require.Equal(t, "Edited subject", gotSubject)
			require.Equal(t, []*mail.Address{{Address: "edited@example.com"}}, gotToList)
		})
	})
}
// TestServer_SendMessage creates a draft and sends it, checking that the sent
// message keeps the draft's identity and fields and gains the Sent label.
func TestServer_SendMessage(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			addrs, err := c.GetAddresses(ctx)
			require.NoError(t, err)

			salts, err := c.GetSalts(ctx)
			require.NoError(t, err)

			keyPass, err := salts.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
			require.NoError(t, err)

			_, addrKRs, err := proton.Unlock(user, addrs, keyPass)
			require.NoError(t, err)

			from := addrs[0]

			draft, err := c.CreateDraft(ctx, addrKRs[from.ID], proton.CreateDraftReq{
				Message: proton.DraftTemplate{
					Subject: "My subject",
					Sender:  &mail.Address{Address: from.Email},
					ToList:  []*mail.Address{{Address: "recipient@example.com"}},
				},
			})
			require.NoError(t, err)

			sent, err := c.SendDraft(ctx, draft.ID, proton.SendDraftReq{})
			require.NoError(t, err)

			// Sending reuses the draft's ID rather than creating a new message.
			require.Equal(t, draft.ID, sent.ID)
			require.Equal(t, from.ID, sent.AddressID)
			require.Equal(t, "My subject", sent.Subject)
			require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, sent.ToList)
			require.Contains(t, sent.LabelIDs, proton.SentLabel)
		})
	})
}
// TestServer_AuthDelete checks that a logged-in client can end its own session.
func TestServer_AuthDelete(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			err := c.AuthDelete(ctx)
			require.NoError(t, err)
		})
	})
}
// TestServer_ForceUpgrade checks that a client whose app version is below the
// server's configured minimum is refused at login, and that the manager's
// AppVersionBadCode error handler is invoked when that happens.
func TestServer_ForceUpgrade(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	s := New()
	defer s.Close()

	s.SetMinAppVersion(semver.MustParse("1.0.0"))

	_, _, err := s.CreateUser("user", []byte("pass"))
	require.NoError(t, err)

	// 0.9.0 is below the 1.0.0 minimum configured above.
	m := proton.New(
		proton.WithHostURL(s.GetHostURL()),
		proton.WithAppVersion("proton_0.9.0"),
		proton.WithTransport(proton.InsecureTransport()),
	)
	defer m.Close()

	var called bool

	m.AddErrorHandler(proton.AppVersionBadCode, func() {
		called = true
	})

	// Login must be rejected. If it unexpectedly succeeds, err is nil and
	// carries no information, so release the session and fail with an
	// explicit message instead.
	if c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")); err == nil {
		c.Close()
		t.Fatal("login succeeded although the client app version is below the server minimum")
	}

	require.True(t, called)
}
// TestServer_Import imports one received message for the user's first address
// and checks that it can be fetched back and decrypted with that address's
// keyring.
func TestServer_Import(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			ctx, cancel := context.WithCancel(ctx)
			defer cancel()

			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			addr, err := c.GetAddresses(ctx)
			require.NoError(t, err)

			salt, err := c.GetSalts(ctx)
			require.NoError(t, err)

			pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
			require.NoError(t, err)

			_, addrKRs, err := proton.Unlock(user, addr, pass)
			require.NoError(t, err)

			// importMessages returns only per-message results (no error value),
			// so the import outcome is checked through the result codes.
			res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, 1)
			require.Len(t, res, 1)
			require.Equal(t, proton.SuccessCode, res[0].Code)

			message, err := c.GetMessage(ctx, res[0].MessageID)
			require.NoError(t, err)

			dec, err := message.Decrypt(addrKRs[message.AddressID])
			require.NoError(t, err)
			require.NotEmpty(t, dec)
		})
	})
}
// TestServer_Labels imports a single message with the given flags, applies a
// sequence of label/unlabel actions to it, and checks the resulting label set
// and that the import flags are preserved.
func TestServer_Labels(t *testing.T) {
	// add and rem tag a label ID as a LabelMessages or UnlabelMessages action.
	type add string
	type rem string

	tests := []struct {
		name         string
		flags        proton.MessageFlag
		actions      []any
		wantLabelIDs []string
		wantError    bool
	}{
		{
			name:         "received flag, no actions",
			flags:        proton.MessageFlagReceived,
			wantLabelIDs: []string{proton.AllMailLabel},
		},
		{
			name:         "sent flag, no actions",
			flags:        proton.MessageFlagSent,
			wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add inbox",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.InboxLabel)},
			wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag, add sent",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.SentLabel)},
			wantLabelIDs: []string{proton.SentLabel, proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add inbox then add archive",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.InboxLabel), add(proton.ArchiveLabel)},
			wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag, add sent then add archive",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.SentLabel), add(proton.ArchiveLabel)},
			wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add inbox then remove inbox",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.InboxLabel), rem(proton.InboxLabel)},
			wantLabelIDs: []string{proton.AllMailLabel},
		},
		{
			name:         "sent flag, add sent then remove sent",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.SentLabel), rem(proton.SentLabel)},
			wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add inbox then remove archive",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.InboxLabel), rem(proton.ArchiveLabel)},
			wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag, add sent then remove archive",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.SentLabel), rem(proton.ArchiveLabel)},
			wantLabelIDs: []string{proton.SentLabel, proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add inbox then remove inbox then add archive",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.InboxLabel), rem(proton.InboxLabel), add(proton.ArchiveLabel)},
			wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag, add sent then remove sent then add archive",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.SentLabel), rem(proton.SentLabel), add(proton.ArchiveLabel)},
			wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add starred",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.StarredLabel)},
			wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag, add starred",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.StarredLabel)},
			wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add inbox, add starred, remove inbox",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.InboxLabel), add(proton.StarredLabel), rem(proton.InboxLabel)},
			wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag, add sent, add starred, remove sent",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.SentLabel), add(proton.StarredLabel), rem(proton.SentLabel)},
			wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel},
		},
		{
			name:         "received flag, add trash, remove trash",
			flags:        proton.MessageFlagReceived,
			actions:      []any{add(proton.TrashLabel), rem(proton.TrashLabel)},
			wantLabelIDs: []string{proton.AllMailLabel},
		},
		{
			name:         "sent flag, add trash, remove trash",
			flags:        proton.MessageFlagSent,
			actions:      []any{add(proton.TrashLabel), rem(proton.TrashLabel)},
			wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel},
		},
	}

	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			ctx, cancel := context.WithCancel(ctx)
			defer cancel()

			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			addr, err := c.GetAddresses(ctx)
			require.NoError(t, err)

			salt, err := c.GetSalts(ctx)
			require.NoError(t, err)

			pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
			require.NoError(t, err)

			_, addrKRs, err := proton.Unlock(user, addr, pass)
			require.NoError(t, err)

			// applyActions runs the label/unlabel actions in order against one
			// message and returns the first error. An unknown action type is
			// reported as an error rather than silently skipped, so a typo in
			// the table cannot make a case pass vacuously.
			applyActions := func(messageID string, actions []any) error {
				for _, action := range actions {
					switch action := action.(type) {
					case add:
						if err := c.LabelMessages(ctx, []string{messageID}, string(action)); err != nil {
							return err
						}

					case rem:
						if err := c.UnlabelMessages(ctx, []string{messageID}, string(action)); err != nil {
							return err
						}

					default:
						return fmt.Errorf("unknown label action type %T", action)
					}
				}

				return nil
			}

			for _, tt := range tests {
				t.Run(tt.name, func(t *testing.T) {
					res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, tt.flags, 1)
					require.Len(t, res, 1)

					// Assert on the error itself so a failure shows what went wrong.
					err := applyActions(res[0].MessageID, tt.actions)
					if tt.wantError {
						require.Error(t, err)
					} else {
						require.NoError(t, err)
					}

					message, err := c.GetMessage(ctx, res[0].MessageID)
					require.NoError(t, err)

					// The message should be in the correct labels.
					require.ElementsMatch(t, tt.wantLabelIDs, message.LabelIDs)

					// The flags should be preserved after import.
					require.True(t, message.Flags&tt.flags == tt.flags)
				})
			}
		})
	})
}
// TestServer_Import_FlagsAndLabels imports single messages with various
// flag/label combinations. It checks that valid ones land in the expected
// labels with their flags preserved, and that forbidden ones are rejected.
func TestServer_Import_FlagsAndLabels(t *testing.T) {
	tests := []struct {
		name         string
		labelIDs     []string
		flags        proton.MessageFlag
		wantLabelIDs []string
		wantError    bool
	}{
		{
			name:         "received flag --> no label",
			flags:        proton.MessageFlagReceived,
			wantLabelIDs: []string{proton.AllMailLabel},
		},
		{
			name:         "received flag --> inbox",
			labelIDs:     []string{proton.InboxLabel},
			flags:        proton.MessageFlagReceived,
			wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag --> sent",
			labelIDs:     []string{proton.SentLabel},
			flags:        proton.MessageFlagSent,
			wantLabelIDs: []string{proton.SentLabel, proton.AllSentLabel, proton.AllMailLabel},
		},
		{
			name:         "received flag --> sent",
			labelIDs:     []string{proton.SentLabel},
			flags:        proton.MessageFlagReceived,
			wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel},
		},
		{
			name:         "sent flag --> inbox",
			labelIDs:     []string{proton.InboxLabel},
			flags:        proton.MessageFlagSent,
			wantLabelIDs: []string{proton.SentLabel, proton.AllSentLabel, proton.AllMailLabel},
		},
		{
			name:         "no flag --> drafts",
			labelIDs:     []string{proton.DraftsLabel},
			wantLabelIDs: []string{proton.DraftsLabel, proton.AllDraftsLabel, proton.AllMailLabel},
		},
		{
			name:      "forbidden: received flag --> All Mail",
			labelIDs:  []string{proton.AllMailLabel},
			flags:     proton.MessageFlagReceived,
			wantError: true,
		},
		{
			name:      "forbidden: sent flag --> All Mail",
			labelIDs:  []string{proton.AllMailLabel},
			flags:     proton.MessageFlagSent,
			wantError: true,
		},
		{
			name:      "forbidden: received flag --> inbox and all mail",
			labelIDs:  []string{proton.InboxLabel, proton.AllMailLabel},
			flags:     proton.MessageFlagReceived,
			wantError: true,
		},
		{
			name:      "forbidden: sent flag --> sent and all mail",
			labelIDs:  []string{proton.SentLabel, proton.AllMailLabel},
			flags:     proton.MessageFlagSent,
			wantError: true,
		},
		{
			name:      "forbidden: received flag --> inbox and sent",
			labelIDs:  []string{proton.InboxLabel, proton.SentLabel},
			flags:     proton.MessageFlagReceived,
			wantError: true,
		},
		{
			name:      "forbidden: sent flag --> inbox and sent",
			labelIDs:  []string{proton.InboxLabel, proton.SentLabel},
			flags:     proton.MessageFlagSent,
			wantError: true,
		},
		{
			name:      "forbidden: received flag --> inbox and archive",
			labelIDs:  []string{proton.InboxLabel, proton.ArchiveLabel},
			flags:     proton.MessageFlagReceived,
			wantError: true,
		},
	}

	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			ctx, cancel := context.WithCancel(ctx)
			defer cancel()

			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			addr, err := c.GetAddresses(ctx)
			require.NoError(t, err)

			salt, err := c.GetSalts(ctx)
			require.NoError(t, err)

			pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
			require.NoError(t, err)

			_, addrKRs, err := proton.Unlock(user, addr, pass)
			require.NoError(t, err)

			for _, tt := range tests {
				t.Run(tt.name, func(t *testing.T) {
					res, err := stream.Collect(ctx, c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), []proton.ImportReq{{
						Metadata: proton.ImportMetadata{
							AddressID: addr[0].ID,
							Flags:     tt.flags,
							LabelIDs:  tt.labelIDs,
						},
						Message: newMessageLiteral("sender@example.com", "recipient@example.com"),
					}}...))

					if tt.wantError {
						require.Error(t, err)
						return
					}

					require.NoError(t, err)

					// Exactly one message was imported; check the count before
					// indexing so a short result fails cleanly instead of panicking.
					require.Len(t, res, 1)
					require.Equal(t, proton.SuccessCode, res[0].Code)

					message, err := c.GetMessage(ctx, res[0].MessageID)
					require.NoError(t, err)

					// The message should be in the correct labels.
					require.ElementsMatch(t, tt.wantLabelIDs, message.LabelIDs)

					// The flags should be preserved after import.
					require.True(t, message.Flags&tt.flags == tt.flags)
				})
			}
		})
	})
}
// TestServer_PublicKeys checks that GetPublicKeys reports an internal recipient
// with exactly one key for an address on the server's own domain, and an
// external recipient with no keys for an address on another domain.
func TestServer_PublicKeys(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		// A second account whose keys the logged-in user will look up.
		_, _, createErr := s.CreateUser("other", []byte("pass"))
		if createErr != nil {
			t.Fatal(createErr)
		}

		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			internalAddr := "other@" + s.GetDomain()

			keys, recipientType, err := c.GetPublicKeys(ctx, internalAddr)
			require.NoError(t, err)
			require.Equal(t, proton.RecipientTypeInternal, recipientType)
			require.Len(t, keys, 1)

			keys, recipientType, err = c.GetPublicKeys(ctx, "other@example.com")
			require.NoError(t, err)
			require.Equal(t, proton.RecipientTypeExternal, recipientType)
			require.Len(t, keys, 0)
		})
	})
}
// TestServer_Proxy logs in through a proxy whose origin is the test server. It
// checks that the proxied responses are correct and that the upstream server
// actually observed calls.
func TestServer_Proxy(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		var observed []Call

		// Record every call that reaches the upstream test server.
		s.AddCallWatcher(func(call Call) {
			observed = append(observed, call)
		})

		withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) {
			upstream := s.GetHostURL()

			proxy := New(WithProxyOrigin(upstream), WithProxyTransport(proton.InsecureTransport()))
			defer proxy.Close()

			proxied := proton.New(proton.WithHostURL(proxy.GetProxyURL()), proton.WithTransport(proton.InsecureTransport()))
			defer proxied.Close()

			// Log in via the proxy; the request is forwarded to the upstream server.
			client, _, err := proxied.NewClientWithLogin(ctx, "user", []byte("pass"))
			require.NoError(t, err)
			defer client.Close()

			// Responses relayed by the proxy must be correct.
			me, err := client.GetUser(ctx)
			require.NoError(t, err)
			require.Equal(t, "user", me.Name)
		})

		// The upstream server must have seen the proxied traffic.
		require.Greater(t, len(observed), 0)
	})
}
// TestServer_Proxy_Cache verifies that a proxy configured with an auth cache
// returns the same auth UID for repeated logins with identical credentials.
func TestServer_Proxy_Cache(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) {
			proxy := New(
				WithProxyOrigin(s.GetHostURL()),
				WithProxyTransport(proton.InsecureTransport()),
				WithAuthCacher(NewAuthCache()),
			)
			defer proxy.Close()

			// Skipping server proof verification is required for the proxy
			// auth cache feature to work.
			mgr := proton.New(
				proton.WithHostURL(proxy.GetProxyURL()),
				proton.WithTransport(proton.InsecureTransport()),
				proton.WithSkipVerifyProofs(),
			)
			defer mgr.Close()

			// loginUID logs in once, closes that client, and returns its auth UID.
			loginUID := func() string {
				c, auth, err := mgr.NewClientWithLogin(ctx, "user", []byte("pass"))
				require.NoError(t, err)
				defer c.Close()

				return auth.UID
			}

			// Three logins should collapse to a single unique UID.
			uids := make([]string, 0, 3)
			for i := 0; i < 3; i++ {
				uids = append(uids, loginUID())
			}

			require.Len(t, xslices.Unique(uids), 1)
		})
	})
}
// TestServer_Proxy_AuthDelete verifies that, with an auth cache enabled, a login
// through the proxy reaches the upstream server, while the subsequent logout
// (AuthDelete) completes without being forwarded upstream.
func TestServer_Proxy_AuthDelete(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) {
			proxy := New(
				WithProxyOrigin(s.GetHostURL()),
				WithProxyTransport(proton.InsecureTransport()),
				WithAuthCacher(NewAuthCache()),
			)
			defer proxy.Close()

			// Unlike TestServer_Proxy_Cache, this manager does NOT skip server
			// proof verification, because only a single login is performed here.
			// NOTE(review): presumably a first login is never answered from the
			// auth cache, so proofs verify normally. Confirm.
			m := proton.New(
				proton.WithHostURL(proxy.GetProxyURL()),
				proton.WithTransport(proton.InsecureTransport()),
			)
			defer m.Close()

			// Watch for login -- the calls should be proxied.
			var login []Call
			s.AddCallWatcher(func(call Call) {
				login = append(login, call)
			})

			// Login -- the call should be proxied to the upstream server.
			c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
			require.NoError(t, err)
			defer c.Close()

			// Assert that the login was proxied. Check the slice itself rather
			// than its length so a failure message shows the recorded calls.
			require.NotEmpty(t, login)

			// Watch for logout -- logout should not be proxied to the upstream server.
			var logout []Call
			s.AddCallWatcher(func(call Call) {
				logout = append(logout, call)
			})

			// Logout -- the call should not be proxied to the upstream server.
			require.NoError(t, c.AuthDelete(ctx))

			// Assert that the logout was not proxied; on failure this prints
			// the unexpected upstream calls.
			require.Empty(t, logout)
		})
	})
}
// TestServer_RealProxy logs in through a proxy built with New()'s default
// options, using real account credentials. It is skipped unless both
// GO_PROTON_API_TEST_USERNAME and GO_PROTON_API_TEST_PASSWORD are set.
func TestServer_RealProxy(t *testing.T) {
	username, password := os.Getenv("GO_PROTON_API_TEST_USERNAME"), os.Getenv("GO_PROTON_API_TEST_PASSWORD")
	if len(username) == 0 || len(password) == 0 {
		t.Skip("skipping test, set the username and password to run")
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	proxy := New()
	defer proxy.Close()

	manager := proton.New(
		proton.WithHostURL(proxy.GetProxyURL()),
		proton.WithTransport(proton.InsecureTransport()),
	)
	defer manager.Close()

	// The login request is forwarded by the proxy to its upstream.
	client, _, err := manager.NewClientWithLogin(ctx, username, []byte(password))
	require.NoError(t, err)
	defer client.Close()

	// The proxied user lookup must describe the account we logged in as.
	me, err := client.GetUser(ctx)
	require.NoError(t, err)
	require.Equal(t, username, me.Name)
}
// TestServer_RealProxy_Cache checks that a proxy with an auth cache returns one
// unique auth UID across repeated logins with real account credentials. It is
// skipped unless both GO_PROTON_API_TEST_USERNAME and
// GO_PROTON_API_TEST_PASSWORD are set.
func TestServer_RealProxy_Cache(t *testing.T) {
	username, password := os.Getenv("GO_PROTON_API_TEST_USERNAME"), os.Getenv("GO_PROTON_API_TEST_PASSWORD")
	if len(username) == 0 || len(password) == 0 {
		t.Skip("skipping test, set the username and password to run")
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	proxy := New(WithAuthCacher(NewAuthCache()))
	defer proxy.Close()

	manager := proton.New(
		proton.WithHostURL(proxy.GetProxyURL()),
		proton.WithTransport(proton.InsecureTransport()),
		proton.WithSkipVerifyProofs(),
	)
	defer manager.Close()

	// loginUID performs one login, closes the client, and returns its auth UID.
	loginUID := func() string {
		c, auth, err := manager.NewClientWithLogin(ctx, username, []byte(password))
		require.NoError(t, err)
		defer c.Close()

		return auth.UID
	}

	// Three logins should produce a single unique auth.
	uids := make([]string, 0, 3)
	for i := 0; i < 3; i++ {
		uids = append(uids, loginUID())
	}

	require.Len(t, xslices.Unique(uids), 1)
}
// TestServer_Messages_Fetch checks how message fetching behaves under a
// transport read limit:
//   - a fetch with only half the bytes it normally needs fails and reads less;
//   - once the limit is cleared, the fetch succeeds and reads exactly the
//     original number of bytes again.
func TestServer_Messages_Fetch(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
				ctl := proton.NewNetCtl()

				mm := proton.New(
					proton.WithHostURL(s.GetHostURL()),
					proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
				)
				defer mm.Close()

				cc, _, err := mm.NewClientWithLogin(ctx, "user", []byte("pass"))
				require.NoError(t, err)
				defer cc.Close()

				// mustFetch downloads every message and requires the fetch to succeed.
				mustFetch := func() {
					res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...))
					require.NoError(t, err)
					require.NotEmpty(t, res)
				}

				// Measure how many bytes an unrestricted fetch reads.
				total := countBytesRead(ctl, mustFetch)

				// With only half that budget, the fetch must fail having read less.
				ctl.SetReadLimit(total / 2)
				require.Less(t, countBytesRead(ctl, func() {
					res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...))
					require.Error(t, err)
					require.Empty(t, res)
				}), total)

				// A limit of 0 clears it again: the fetch succeeds and reads the
				// same number of bytes as before. Expected value comes first.
				ctl.SetReadLimit(0)
				require.Equal(t, total, countBytesRead(ctl, mustFetch))
			})
		})
	}, WithTLS(false))
}
// TestServer_Messages_Status checks that the manager reports StatusDown to its
// status observers when a fetch fails because the transport read budget has
// been exhausted.
func TestServer_Messages_Status(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
				ctl := proton.NewNetCtl()

				mm := proton.New(
					proton.WithHostURL(s.GetHostURL()),
					proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
				)
				defer mm.Close()

				statusCh := make(chan proton.Status, 1)

				// Keep only the first reported status and drop later ones. A
				// blocking send into the 1-slot channel would hang the observer
				// if it fires more than once.
				// NOTE(review): it may fire repeatedly, e.g. once per concurrently
				// failing request. Confirm against the manager's implementation.
				mm.AddStatusObserver(func(status proton.Status) {
					select {
					case statusCh <- status:
					default:
					}
				})

				cc, _, err := mm.NewClientWithLogin(ctx, "user", []byte("pass"))
				require.NoError(t, err)
				defer cc.Close()

				// Measure how many bytes an unrestricted fetch reads.
				total := countBytesRead(ctl, func() {
					res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...))
					require.NoError(t, err)
					require.NotEmpty(t, res)
				})

				// Halve the budget so the next fetch fails.
				ctl.SetReadLimit(total / 2)

				res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...))
				require.Error(t, err)
				require.Empty(t, res)

				require.Equal(t, proton.StatusDown, <-statusCh)
			})
		})
	}, WithTLS(false))
}
// TestServer_Labels_Duplicates verifies that creating a second label with the
// same name as an existing one is rejected.
func TestServer_Labels_Duplicates(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			req := proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeLabel,
			}

			// Use the test-scoped ctx (not context.Background) so requests are
			// cancelled with the server, as in the rest of this file.
			label, err := c.CreateLabel(ctx, req)
			require.NoError(t, err)
			require.Equal(t, req.Name, label.Name)

			// A second label with the same name must be rejected.
			_, err = c.CreateLabel(ctx, req)
			require.Error(t, err)
		})
	})
}
// TestServer_Labels_Duplicates_Update verifies the name-uniqueness rule on label
// updates:
//   - renaming to another label's name fails;
//   - keeping the label's own name while changing its color succeeds;
//   - renaming to a fresh name succeeds.
func TestServer_Labels_Duplicates_Update(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			label1, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeLabel,
			})
			require.NoError(t, err)

			label2, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeLabel,
			})
			require.NoError(t, err)

			// Updating label1 with label2's name should fail.
			_, err = c.UpdateLabel(ctx, label1.ID, proton.UpdateLabelReq{
				Name:  label2.Name,
				Color: label1.Color,
			})
			require.Error(t, err)

			// Updating label1's color while preserving its name should succeed.
			_, err = c.UpdateLabel(ctx, label1.ID, proton.UpdateLabelReq{
				Name:  label1.Name,
				Color: "#f00",
			})
			require.NoError(t, err)

			// Updating label1 with a new name should succeed.
			_, err = c.UpdateLabel(ctx, label1.ID, proton.UpdateLabelReq{
				Name:  uuid.NewString(),
				Color: label1.Color,
			})
			require.NoError(t, err)
		})
	})
}
// TestServer_Labels_Subfolders verifies that nested folders report their full
// ancestry in Path, from the root folder down to the folder itself.
func TestServer_Labels_Subfolders(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			parent, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeFolder,
			})
			require.NoError(t, err)

			child, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:     uuid.NewString(),
				ParentID: parent.ID,
				Color:    "#f66",
				Type:     proton.LabelTypeFolder,
			})
			require.NoError(t, err)
			require.Equal(t, []string{parent.Name, child.Name}, child.Path)

			child2, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:     uuid.NewString(),
				ParentID: child.ID,
				Color:    "#f66",
				Type:     proton.LabelTypeFolder,
			})
			require.NoError(t, err)
			require.Equal(t, []string{parent.Name, child.Name, child2.Name}, child2.Path)
		})
	})
}
// TestServer_Labels_Subfolders_Reassign verifies that moving a folder to another
// parent, and then to no parent, updates its Path accordingly.
func TestServer_Labels_Subfolders_Reassign(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			parent1, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeFolder,
			})
			require.NoError(t, err)

			parent2, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeFolder,
			})
			require.NoError(t, err)

			// Create a child initially under parent1.
			child, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:     uuid.NewString(),
				ParentID: parent1.ID,
				Color:    "#f66",
				Type:     proton.LabelTypeFolder,
			})
			require.NoError(t, err)
			require.Equal(t, []string{parent1.Name, child.Name}, child.Path)

			// Reassign the child to parent2.
			child2, err := c.UpdateLabel(ctx, child.ID, proton.UpdateLabelReq{
				Name:     child.Name,
				Color:    child.Color,
				ParentID: parent2.ID,
			})
			require.NoError(t, err)
			require.Equal(t, []string{parent2.Name, child.Name}, child2.Path)

			// Reassign the child to no parent.
			child3, err := c.UpdateLabel(ctx, child.ID, proton.UpdateLabelReq{
				Name:     child2.Name,
				Color:    child2.Color,
				ParentID: "",
			})
			require.NoError(t, err)
			require.Equal(t, []string{child3.Name}, child3.Path)
		})
	})
}
// TestServer_Labels_Subfolders_DeleteParentWithChildren verifies that deleting a
// parent folder also deletes its child, leaving unrelated folders intact.
func TestServer_Labels_Subfolders_DeleteParentWithChildren(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			parent, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeFolder,
			})
			require.NoError(t, err)

			child, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:     uuid.NewString(),
				ParentID: parent.ID,
				Color:    "#f66",
				Type:     proton.LabelTypeFolder,
			})
			require.NoError(t, err)
			require.Equal(t, []string{parent.Name, child.Name}, child.Path)

			other, err := c.CreateLabel(ctx, proton.CreateLabelReq{
				Name:  uuid.NewString(),
				Color: "#f66",
				Type:  proton.LabelTypeFolder,
			})
			require.NoError(t, err)

			// Get labels before.
			before, err := c.GetLabels(ctx, proton.LabelTypeFolder)
			require.NoError(t, err)

			// Delete the parent.
			require.NoError(t, c.DeleteLabel(ctx, parent.ID))

			// Get labels after.
			after, err := c.GetLabels(ctx, proton.LabelTypeFolder)
			require.NoError(t, err)

			// Both parent and child are deleted. Asserting on the slice (rather
			// than comparing lengths) fails cleanly before after[0] is indexed.
			require.Len(t, after, len(before)-2)

			// The only label left is the other one.
			require.Equal(t, other.ID, after[0].ID)
		})
	})
}
// TestServer_AddressCreateDelete takes an extra address through its lifecycle
// (created enabled, then disabled, then deleted) and checks the user's address
// list after each step.
func TestServer_AddressCreateDelete(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			// Create an address.
			alias, err := s.CreateAddress(user.ID, "alias@example.com", []byte("pass"))
			require.NoError(t, err)

			// The user should have two addresses, both enabled.
			{
				addr, err := c.GetAddresses(ctx)
				require.NoError(t, err)
				require.Len(t, addr, 2)
				require.Equal(t, proton.AddressStatusEnabled, addr[0].Status)
				require.Equal(t, proton.AddressStatusEnabled, addr[1].Status)
			}

			// Disable the alias.
			require.NoError(t, c.DisableAddress(ctx, alias))

			// The user should have two addresses, the primary enabled and the alias disabled.
			{
				addr, err := c.GetAddresses(ctx)
				require.NoError(t, err)
				require.Len(t, addr, 2)
				require.Equal(t, proton.AddressStatusEnabled, addr[0].Status)
				require.Equal(t, proton.AddressStatusDisabled, addr[1].Status)
			}

			// Delete the alias.
			require.NoError(t, c.DeleteAddress(ctx, alias))

			// The user should have one address, the primary enabled.
			{
				addr, err := c.GetAddresses(ctx)
				require.NoError(t, err)
				require.Len(t, addr, 1)
				require.Equal(t, proton.AddressStatusEnabled, addr[0].Status)
			}
		})
	})
}
// TestServer_AddressOrder verifies two things:
//   - new addresses are listed after the primary, in creation order;
//   - OrderAddresses rearranges the list to exactly the requested order.
func TestServer_AddressOrder(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
			user, err := c.GetUser(ctx)
			require.NoError(t, err)

			primary, err := c.GetAddresses(ctx)
			require.NoError(t, err)
			require.NotEmpty(t, primary)

			// Create 3 additional addresses.
			addr1, err := s.CreateAddress(user.ID, "addr1@example.com", []byte("pass"))
			require.NoError(t, err)

			addr2, err := s.CreateAddress(user.ID, "addr2@example.com", []byte("pass"))
			require.NoError(t, err)

			addr3, err := s.CreateAddress(user.ID, "addr3@example.com", []byte("pass"))
			require.NoError(t, err)

			addresses, err := c.GetAddresses(ctx)
			require.NoError(t, err)
			// Check the count first so the indexing below cannot panic.
			require.Len(t, addresses, 4)

			// Check the order.
			require.Equal(t, primary[0].ID, addresses[0].ID)
			require.Equal(t, addr1, addresses[1].ID)
			require.Equal(t, addr2, addresses[2].ID)
			require.Equal(t, addr3, addresses[3].ID)

			// Update the order.
			require.NoError(t, c.OrderAddresses(ctx, proton.OrderAddressesReq{
				AddressIDs: []string{addr3, addr2, addr1, primary[0].ID},
			}))

			// Check the order.
			addresses, err = c.GetAddresses(ctx)
			require.NoError(t, err)
			require.Len(t, addresses, 4)

			require.Equal(t, addr3, addresses[0].ID)
			require.Equal(t, addr2, addresses[1].ID)
			require.Equal(t, addr1, addresses[2].ID)
			require.Equal(t, primary[0].ID, addresses[3].ID)
		})
	})
}
// TestServer_Domains checks that the manager reports exactly one domain: the
// test server's own.
func TestServer_Domains(t *testing.T) {
	withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
		got, err := m.GetDomains(ctx)
		require.NoError(t, err)

		want := []string{s.GetDomain()}
		require.Equal(t, want, got)
	})
}
// withServer starts a test server configured with opts and a manager pointed at
// it (in-memory cookie jar, insecure transport), then runs fn with a
// cancellable context. When fn returns, the deferred calls run in reverse
// order: the manager is closed, then the server, then the context is cancelled.
func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	srv := New(opts...)
	defer srv.Close()

	mgr := proton.New(
		proton.WithHostURL(srv.GetHostURL()),
		proton.WithCookieJar(newTestCookieJar()),
		proton.WithTransport(proton.InsecureTransport()),
	)
	defer mgr.Close()

	fn(ctx, srv, mgr)
}
// withUser creates an account named user with password pass on s, logs in
// through m, and runs fn with the resulting client. The client is closed when
// fn returns. Setup failures are reported to t and stop the test.
func withUser(ctx context.Context, t *testing.T, s *Server, m *proton.Manager, user, pass string, fn func(c *proton.Client)) {
	// Mark as a helper so setup failures are attributed to the calling test line.
	t.Helper()

	_, _, err := s.CreateUser(user, []byte(pass))
	require.NoError(t, err)

	c, _, err := m.NewClientWithLogin(ctx, user, []byte(pass))
	require.NoError(t, err)
	defer c.Close()

	fn(c)
}
// withMessages unlocks the user's keys with pass, then imports count received,
// unread, unlabelled messages into the user's first address. It calls fn with
// the IDs of the imported messages.
func withMessages(ctx context.Context, t *testing.T, c *proton.Client, pass string, count int, fn func([]string)) {
	// Mark as a helper so setup failures are attributed to the calling test line.
	t.Helper()

	user, err := c.GetUser(ctx)
	require.NoError(t, err)

	addr, err := c.GetAddresses(ctx)
	require.NoError(t, err)

	salt, err := c.GetSalts(ctx)
	require.NoError(t, err)

	// Derive the key passphrase from the login password and the primary key's salt.
	keyPass, err := salt.SaltForKey([]byte(pass), user.Keys.Primary().ID)
	require.NoError(t, err)

	_, addrKRs, err := proton.Unlock(user, addr, keyPass)
	require.NoError(t, err)

	fn(xslices.Map(importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, count), func(res proton.ImportRes) string {
		return res.MessageID
	}))
}
// importMessages imports count identical literal messages into addrID. The
// messages are encrypted with addrKR and carry labelIDs and flags, and are
// marked unread. It returns the per-message import results and fails t if the
// import stream reports an error.
func importMessages(
	ctx context.Context,
	t *testing.T,
	c *proton.Client,
	addrID string,
	addrKR *crypto.KeyRing,
	labelIDs []string,
	flags proton.MessageFlag,
	count int,
) []proton.ImportRes {
	// Mark as a helper so import failures are attributed to the calling test line.
	t.Helper()

	req := iterator.Collect(iterator.Map(iterator.Counter(count), func(int) proton.ImportReq {
		return proton.ImportReq{
			Metadata: proton.ImportMetadata{
				AddressID: addrID,
				LabelIDs:  labelIDs,
				Flags:     flags,
				Unread:    true,
			},
			Message: newMessageLiteral("sender@example.com", "recipient@example.com"),
		}
	}))

	res, err := stream.Collect(ctx, c.ImportMessages(ctx, addrKR, runtime.NumCPU(), runtime.NumCPU(), req...))
	require.NoError(t, err)

	return res
}
// countBytesRead registers a read callback on ctl, runs fn, and returns the
// total number of bytes that callback was told about.
//
// NOTE(review): each call registers a new OnRead callback. This block does not
// show whether earlier callbacks are ever removed; if they are not, they keep
// updating their own (no longer read) counters.
func countBytesRead(ctl *proton.NetCtl, fn func()) uint64 {
	var read uint64

	ctl.OnRead(func(b []byte) {
		atomic.AddUint64(&read, uint64(len(b)))
	})

	fn()

	// The counter is updated with atomic adds, possibly from transport
	// goroutines, so it must also be read atomically. A plain load here would
	// be a data race.
	return atomic.LoadUint64(&read)
}
// testCookieJar is a minimal cookie jar keyed by URL host, guarded by a
// read/write mutex. Each SetCookies call replaces whatever was stored for that
// host; no expiry, path, or domain matching is performed.
type testCookieJar struct {
	cookies map[string][]*http.Cookie
	lock    sync.RWMutex
}

// newTestCookieJar returns an empty jar ready for use.
func newTestCookieJar() *testCookieJar {
	jar := new(testCookieJar)
	jar.cookies = map[string][]*http.Cookie{}

	return jar
}

// SetCookies stores cookies for u's host, replacing any previously stored set.
func (j *testCookieJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
	host := u.Host

	j.lock.Lock()
	j.cookies[host] = cookies
	j.lock.Unlock()
}

// Cookies returns the cookies last stored for u's host, or nil if there are none.
func (j *testCookieJar) Cookies(u *url.URL) []*http.Cookie {
	j.lock.RLock()
	stored := j.cookies[u.Host]
	j.lock.RUnlock()

	return stored
}
// must returns t when err is nil and panics with err otherwise. It is meant for
// test code where any failure should abort immediately.
func must[T any](t T, err error) T {
	if err == nil {
		return t
	}

	panic(err)
}
// elementsMatch reports whether want and got contain the same elements with the
// same multiplicities, in any order (multiset equality).
//
// Each element of want consumes one matching element of got, so duplicates are
// counted: [a a b] does not match [a b b]. The previous membership-only check
// would have accepted that pair.
func elementsMatch[T comparable](want, got []T) bool {
	if len(want) != len(got) {
		return false
	}

	// Work on a copy so the caller's slice is never reordered or modified.
	remaining := slices.Clone(got)

	for _, w := range want {
		i := slices.Index(remaining, w)
		if i < 0 {
			return false
		}

		remaining = slices.Delete(remaining, i, i+1)
	}

	return true
}