test(GODT-2181): Handle quark commands in test server

This commit is contained in:
James Houlahan
2022-12-08 21:23:32 +01:00
committed by James
parent 0afe67dc1c
commit a847d9b892
29 changed files with 666 additions and 242 deletions

View File

@@ -44,3 +44,21 @@ func (c *Client) OrderAddresses(ctx context.Context, req OrderAddressesReq) erro
return r.SetBody(req).Put("/core/v4/addresses/order") return r.SetBody(req).Put("/core/v4/addresses/order")
}) })
} }
func (c *Client) EnableAddress(ctx context.Context, addressID string) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Put("/core/v4/addresses/" + addressID + "/enable")
})
}
func (c *Client) DisableAddress(ctx context.Context, addressID string) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Put("/core/v4/addresses/" + addressID + "/disable")
})
}
func (c *Client) DeleteAddress(ctx context.Context, addressID string) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/core/v4/addresses/" + addressID)
})
}

View File

@@ -1,33 +0,0 @@
package proton
import "sync/atomic"
type atomicUint64 struct {
v uint64
}
func (x *atomicUint64) Load() uint64 { return atomic.LoadUint64(&x.v) }
func (x *atomicUint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) }
func (x *atomicUint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) }
func (x *atomicUint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) }
type atomicBool struct {
v uint32
}
func (x *atomicBool) Load() bool { return atomic.LoadUint32(&x.v) != 0 }
func (x *atomicBool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) }
func (x *atomicBool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 }
func b32(b bool) uint32 {
if b {
return 1
}
return 0
}

View File

@@ -16,7 +16,7 @@ func TestAuth(t *testing.T) {
s := server.New() s := server.New()
defer s.Close() defer s.Close()
_, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) _, _, err := s.CreateUser("user", []byte("password"))
require.NoError(t, err) require.NoError(t, err)
m := proton.New( m := proton.New(
@@ -61,7 +61,7 @@ func TestAuth_Refresh(t *testing.T) {
defer s.Close() defer s.Close()
// Create a user on the server. // Create a user on the server.
userID, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) userID, _, err := s.CreateUser("user", []byte("password"))
require.NoError(t, err) require.NoError(t, err)
// The auth is valid for 4 seconds. // The auth is valid for 4 seconds.
@@ -106,7 +106,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
defer s.Close() defer s.Close()
// Create a user on the server. // Create a user on the server.
userID, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) userID, _, err := s.CreateUser("user", []byte("password"))
require.NoError(t, err) require.NoError(t, err)
// The auth is valid for 4 seconds. // The auth is valid for 4 seconds.
@@ -149,7 +149,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
defer s.Close() defer s.Close()
// Create a user on the server. // Create a user on the server.
userID, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) userID, _, err := s.CreateUser("user", []byte("password"))
require.NoError(t, err) require.NoError(t, err)
m := proton.New( m := proton.New(

View File

@@ -22,7 +22,7 @@ func TestEventStreamer(t *testing.T) {
proton.WithTransport(proton.InsecureTransport()), proton.WithTransport(proton.InsecureTransport()),
) )
_, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) _, _, err := s.CreateUser("user", []byte("password"))
require.NoError(t, err) require.NoError(t, err)
c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password")) c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))

5
go.mod
View File

@@ -8,18 +8,19 @@ require (
github.com/ProtonMail/go-crypto v0.0.0-20220824120805-4b6e5c587895 github.com/ProtonMail/go-crypto v0.0.0-20220824120805-4b6e5c587895
github.com/ProtonMail/go-srp v0.0.5 github.com/ProtonMail/go-srp v0.0.5
github.com/ProtonMail/gopenpgp/v2 v2.4.10 github.com/ProtonMail/gopenpgp/v2 v2.4.10
github.com/PuerkitoBio/goquery v1.8.0
github.com/bradenaw/juniper v0.8.0 github.com/bradenaw/juniper v0.8.0
github.com/emersion/go-message v0.16.0 github.com/emersion/go-message v0.16.0
github.com/emersion/go-vcard v0.0.0-20220507122617-d4056df0ec4a github.com/emersion/go-vcard v0.0.0-20220507122617-d4056df0ec4a
github.com/gin-gonic/gin v1.8.1 github.com/gin-gonic/gin v1.8.1
github.com/go-resty/resty/v2 v2.7.0 github.com/go-resty/resty/v2 v2.7.0
github.com/google/go-cmp v0.5.8
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
github.com/urfave/cli/v2 v2.20.3 github.com/urfave/cli/v2 v2.20.3
go.uber.org/goleak v1.1.12 go.uber.org/goleak v1.1.12
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 golang.org/x/exp v0.0.0-20220827204233-334a2380cb91
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b
google.golang.org/grpc v1.50.1 google.golang.org/grpc v1.50.1
google.golang.org/protobuf v1.28.0 google.golang.org/protobuf v1.28.0
) )
@@ -27,6 +28,7 @@ require (
require ( require (
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect
github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f // indirect github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f // indirect
github.com/andybalholm/cascadia v1.3.1 // indirect
github.com/cloudflare/circl v1.2.0 // indirect github.com/cloudflare/circl v1.2.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect github.com/cronokirby/saferith v0.33.0 // indirect
@@ -50,7 +52,6 @@ require (
github.com/ugorji/go/codec v1.2.7 // indirect github.com/ugorji/go/codec v1.2.7 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect

6
go.sum
View File

@@ -19,6 +19,10 @@ github.com/ProtonMail/go-srp v0.0.5 h1:xhUioxZgDbCnpo9JehyFhwwsn9JLWkUGfB0oiKXgi
github.com/ProtonMail/go-srp v0.0.5/go.mod h1:06iYHtLXW8vjLtccWj++x3MKy65sIT8yZd7nrJF49rs= github.com/ProtonMail/go-srp v0.0.5/go.mod h1:06iYHtLXW8vjLtccWj++x3MKy65sIT8yZd7nrJF49rs=
github.com/ProtonMail/gopenpgp/v2 v2.4.10 h1:EYgkxzwmQvsa6kxxkgP1AwzkFqKHscF2UINxaSn6rdI= github.com/ProtonMail/gopenpgp/v2 v2.4.10 h1:EYgkxzwmQvsa6kxxkgP1AwzkFqKHscF2UINxaSn6rdI=
github.com/ProtonMail/gopenpgp/v2 v2.4.10/go.mod h1:CTRA7/toc/4DxDy5Du4hPDnIZnJvXSeQ8LsRTOUJoyc= github.com/ProtonMail/gopenpgp/v2 v2.4.10/go.mod h1:CTRA7/toc/4DxDy5Du4hPDnIZnJvXSeQ8LsRTOUJoyc=
github.com/PuerkitoBio/goquery v1.8.0 h1:PJTF7AmFCFKk1N6V6jmKfrNH9tV5pNE6lZMkG0gta/U=
github.com/PuerkitoBio/goquery v1.8.0/go.mod h1:ypIiRMtY7COPGk+I/YbZLbxsxn9g5ejnI2HSMtkjZvI=
github.com/andybalholm/cascadia v1.3.1 h1:nhxRkql1kdYCc8Snf7D5/D3spOX+dBgjA6u8x004T2c=
github.com/andybalholm/cascadia v1.3.1/go.mod h1:R4bJ1UQfqADjvDa4P6HZHLh/3OxWWEqc0Sk8XGwHqvA=
github.com/bradenaw/juniper v0.8.0 h1:sdanLNdJbLjcLj993VYIwUHlUVkLzvgiD/x9O7cvvxk= github.com/bradenaw/juniper v0.8.0 h1:sdanLNdJbLjcLj993VYIwUHlUVkLzvgiD/x9O7cvvxk=
github.com/bradenaw/juniper v0.8.0/go.mod h1:Z2B7aJlQ7xbfWsnMLROj5t/5FQ94/MkIdKC30J4WvzI= github.com/bradenaw/juniper v0.8.0/go.mod h1:Z2B7aJlQ7xbfWsnMLROj5t/5FQ94/MkIdKC30J4WvzI=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
@@ -79,7 +83,6 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -174,6 +177,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210916014120-12bc252f5db8/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY= golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY=

View File

@@ -41,7 +41,7 @@ func createTestMessages(t *testing.T, c *proton.Client, pass string, count int)
Flags: proton.MessageFlagReceived, Flags: proton.MessageFlagReceived,
Unread: true, Unread: true,
}, },
Message: []byte(fmt.Sprintf("From: sender@pm.me\r\nReceiver: recipient@pm.me\r\nSubject: %v\r\n\r\nHello World!", uuid.New())), Message: []byte(fmt.Sprintf("From: sender@example.com\r\nReceiver: recipient@example.com\r\nSubject: %v\r\n\r\nHello World!", uuid.New())),
} }
})) }))

View File

@@ -1,10 +1,15 @@
package proton package proton
import ( import (
"bytes"
"context" "context"
"strings" "strings"
"github.com/PuerkitoBio/goquery"
"golang.org/x/net/html"
) )
// Quark runs a quark command.
func (m *Manager) Quark(ctx context.Context, command string, args ...string) error { func (m *Manager) Quark(ctx context.Context, command string, args ...string) error {
if _, err := m.r(ctx).SetQueryParam("strInput", strings.Join(args, " ")).Get("/internal/quark/" + command); err != nil { if _, err := m.r(ctx).SetQueryParam("strInput", strings.Join(args, " ")).Get("/internal/quark/" + command); err != nil {
return err return err
@@ -12,3 +17,18 @@ func (m *Manager) Quark(ctx context.Context, command string, args ...string) err
return nil return nil
} }
// QuarkRes is the same as Quark, but returns the content extracted from the response body.
func (m *Manager) QuarkRes(ctx context.Context, command string, args ...string) ([]byte, error) {
res, err := m.r(ctx).SetQueryParam("strInput", strings.Join(args, " ")).Get("/internal/quark/" + command)
if err != nil {
return nil, err
}
doc, err := html.Parse(bytes.NewReader(res.Body()))
if err != nil {
return nil, err
}
return []byte(goquery.NewDocumentFromNode(doc).Find(".content").Text()), nil
}

View File

@@ -159,7 +159,7 @@ func TestStatus_NoReadExistingConn(t *testing.T) {
s := server.New() s := server.New()
defer s.Close() defer s.Close()
_, _, err := s.CreateUser("user", "user@pm.me", []byte("pass")) _, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
netCtl := proton.NewNetCtl() netCtl := proton.NewNetCtl()
@@ -197,7 +197,7 @@ func TestStatus_NoWriteExistingConn(t *testing.T) {
s := server.New() s := server.New()
defer s.Close() defer s.Close()
_, _, err := s.CreateUser("user", "user@pm.me", []byte("pass")) _, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
netCtl := proton.NewNetCtl() netCtl := proton.NewNetCtl()

View File

@@ -51,7 +51,7 @@ func TestAuthRefresh(t *testing.T) {
s := server.New() s := server.New()
defer s.Close() defer s.Close()
_, _, err := s.CreateUser("user", "email@pm.me", []byte("pass")) _, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
m := proton.New( m := proton.New(

View File

@@ -4,7 +4,6 @@ import (
"net/http" "net/http"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xslices"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
@@ -25,20 +24,45 @@ func (s *Server) handleGetAddresses() gin.HandlerFunc {
func (s *Server) handleGetAddress() gin.HandlerFunc { func (s *Server) handleGetAddress() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
addresses, err := s.b.GetAddresses(c.GetString("UserID")) address, err := s.b.GetAddress(c.GetString("UserID"), c.Param("addressID"))
if err != nil { if err != nil {
c.AbortWithStatus(http.StatusInternalServerError) c.AbortWithStatus(http.StatusInternalServerError)
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"Address": addresses[xslices.IndexFunc(addresses, func(address proton.Address) bool { "Address": address,
return address.ID == c.Param("addressID")
})],
}) })
} }
} }
func (s *Server) handlePutAddressEnable() gin.HandlerFunc {
return func(c *gin.Context) {
if err := s.b.EnableAddress(c.GetString("UserID"), c.Param("addressID")); err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
}
}
func (s *Server) handlePutAddressDisable() gin.HandlerFunc {
return func(c *gin.Context) {
if err := s.b.DisableAddress(c.GetString("UserID"), c.Param("addressID")); err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
}
}
func (s *Server) handleDeleteAddress() gin.HandlerFunc {
return func(c *gin.Context) {
if err := s.b.DeleteAddress(c.GetString("UserID"), c.Param("addressID")); err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
}
}
func (s *Server) handlePutAddressesOrder() gin.HandlerFunc { func (s *Server) handlePutAddressesOrder() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
var req proton.OrderAddressesReq var req proton.OrderAddressesReq

View File

@@ -10,6 +10,7 @@ type address struct {
addrID string addrID string
email string email string
order int order int
status proton.AddressStatus
keys []key keys []key
} }
@@ -20,7 +21,7 @@ func (add *address) toAddress() proton.Address {
Send: true, Send: true,
Receive: true, Receive: true,
Status: proton.AddressStatusEnabled, Status: add.status,
Order: add.order, Order: add.order,
DisplayName: add.email, DisplayName: add.email,

View File

@@ -50,6 +50,16 @@ func (b *Backend) GetAddressID(email string) (string, error) {
}) })
} }
func (b *Backend) GetAddress(userID, addrID string) (proton.Address, error) {
return withAcc(b, userID, func(acc *account) (proton.Address, error) {
if addr, ok := acc.addresses[addrID]; ok {
return addr.toAddress(), nil
}
return proton.Address{}, errors.New("no such address")
})
}
func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) { func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Address, error) { return withAcc(b, userID, func(acc *account) ([]proton.Address, error) {
return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address { return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address {
@@ -58,6 +68,55 @@ func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) {
}) })
} }
func (b *Backend) EnableAddress(userID, addrID string) error {
return b.withAcc(userID, func(acc *account) error {
acc.addresses[addrID].status = proton.AddressStatusEnabled
updateID, err := b.newUpdate(&addressUpdated{addressID: addrID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) DisableAddress(userID, addrID string) error {
return b.withAcc(userID, func(acc *account) error {
acc.addresses[addrID].status = proton.AddressStatusDisabled
updateID, err := b.newUpdate(&addressUpdated{addressID: addrID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) DeleteAddress(userID, addrID string) error {
return b.withAcc(userID, func(acc *account) error {
if acc.addresses[addrID].status != proton.AddressStatusDisabled {
return errors.New("address is not disabled")
}
delete(acc.addresses, addrID)
updateID, err := b.newUpdate(&addressDeleted{addressID: addrID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) SetAddressOrder(userID string, addrIDs []string) error { func (b *Backend) SetAddressOrder(userID string, addrIDs []string) error {
return b.withAcc(userID, func(acc *account) error { return b.withAcc(userID, func(acc *account) error {
for i, addrID := range addrIDs { for i, addrID := range addrIDs {

View File

@@ -16,6 +16,8 @@ import (
) )
type Backend struct { type Backend struct {
domain string
accounts map[string]*account accounts map[string]*account
accLock sync.RWMutex accLock sync.RWMutex
@@ -40,8 +42,9 @@ type Backend struct {
authLife time.Duration authLife time.Duration
} }
func New(authLife time.Duration) *Backend { func New(authLife time.Duration, domain string) *Backend {
return &Backend{ return &Backend{
domain: domain,
accounts: make(map[string]*account), accounts: make(map[string]*account),
attachments: make(map[string]*attachment), attachments: make(map[string]*attachment),
attData: make(map[string][]byte), attData: make(map[string][]byte),
@@ -49,8 +52,7 @@ func New(authLife time.Duration) *Backend {
labels: make(map[string]*label), labels: make(map[string]*label),
updates: make(map[ID]update), updates: make(map[ID]update),
srp: make(map[string]*srp.Server), srp: make(map[string]*srp.Server),
authLife: authLife,
authLife: authLife,
} }
} }
@@ -190,31 +192,42 @@ func (b *Backend) RemoveUserKey(userID, keyID string) error {
return nil return nil
} }
func (b *Backend) CreateAddress(userID, email string, password []byte) (string, error) { func (b *Backend) CreateAddress(userID, email string, password []byte, withKey bool) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) { return withAcc(b, userID, func(acc *account) (string, error) {
token, err := crypto.RandomToken(32) var keys []key
if err != nil {
return "", err
}
armKey, err := GenerateKey(acc.username, email, token, "rsa", 2048) if withKey {
if err != nil { token, err := crypto.RandomToken(32)
return "", err if err != nil {
} return "", err
}
passphrase, err := hashPassword([]byte(password), acc.salt) armKey, err := GenerateKey(acc.username, email, token, "rsa", 2048)
if err != nil { if err != nil {
return "", err return "", err
} }
userKR, err := acc.keys[0].unlock(passphrase) passphrase, err := hashPassword([]byte(password), acc.salt)
if err != nil { if err != nil {
return "", err return "", err
} }
encToken, sigToken, err := encryptWithSignature(userKR, token) userKR, err := acc.keys[0].unlock(passphrase)
if err != nil { if err != nil {
return "", err return "", err
}
encToken, sigToken, err := encryptWithSignature(userKR, token)
if err != nil {
return "", err
}
keys = append(keys, key{
keyID: uuid.NewString(),
key: armKey,
tok: encToken,
sig: sigToken,
})
} }
addressID := uuid.NewString() addressID := uuid.NewString()
@@ -223,12 +236,8 @@ func (b *Backend) CreateAddress(userID, email string, password []byte) (string,
addrID: addressID, addrID: addressID,
email: email, email: email,
order: len(acc.addresses) + 1, order: len(acc.addresses) + 1,
keys: []key{{ status: proton.AddressStatusEnabled,
keyID: uuid.NewString(), keys: keys,
key: armKey,
tok: encToken,
sig: sigToken,
}},
} }
updateID, err := b.newUpdate(&addressCreated{addressID: addressID}) updateID, err := b.newUpdate(&addressCreated{addressID: addressID})

98
server/backend/quark.go Normal file
View File

@@ -0,0 +1,98 @@
package backend
import (
"flag"
"fmt"
"github.com/ProtonMail/go-proton-api"
)
func (s *Backend) RunQuarkCommand(command string, args ...string) (any, error) {
switch command {
case "encryption:id":
return s.quarkEncryptionID(args...)
case "user:create":
return s.quarkUserCreate(args...)
case "user:create:address":
return s.quarkUserCreateAddress(args...)
default:
return nil, fmt.Errorf("unknown command: %s", command)
}
}
func (s *Backend) quarkEncryptionID(args ...string) (string, error) {
fs := flag.NewFlagSet("encryption:id", flag.ContinueOnError)
// Required arguments.
// arg0: value
decrypt := fs.Bool("decrypt", false, "decrypt the given encrypted ID")
if err := fs.Parse(args); err != nil {
return "", err
}
// TODO: Encrypt/decrypt are currently no-op.
if *decrypt {
return fs.Arg(0), nil
} else {
return fs.Arg(0), nil
}
}
func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) {
fs := flag.NewFlagSet("user:create", flag.ContinueOnError)
// Required arguments.
name := fs.String("name", "", "new user's name")
pass := fs.String("password", "", "new user's password")
// Optional arguments.
newAddr := fs.Bool("create-address", false, "create the user's default address, will not automatically setup the address key")
genKeys := fs.String("gen-keys", "", "generate new address keys for the user")
if err := fs.Parse(args); err != nil {
return proton.User{}, err
}
userID, err := s.CreateUser(*name, []byte(*pass))
if err != nil {
return proton.User{}, fmt.Errorf("failed to create user: %w", err)
}
// TODO: Create keys of different types (we always use RSA2048).
if *newAddr || *genKeys != "" {
if _, err := s.CreateAddress(userID, *name+"@"+s.domain, []byte(*pass), *genKeys != ""); err != nil {
return proton.User{}, fmt.Errorf("failed to create address with keys: %w", err)
}
}
return s.GetUser(userID)
}
func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error) {
fs := flag.NewFlagSet("user:create:address", flag.ContinueOnError)
// Required arguments.
// arg0: userID
// arg1: password
// arg2: email
// Optional arguments.
genKeys := fs.String("gen-keys", "", "generate new address keys for the user")
if err := fs.Parse(args); err != nil {
return proton.Address{}, err
}
// TODO: Create keys of different types (we always use RSA2048).
addrID, err := s.CreateAddress(fs.Arg(0), fs.Arg(2), []byte(fs.Arg(1)), *genKeys != "")
if err != nil {
return proton.Address{}, fmt.Errorf("failed to create address with keys: %w", err)
}
return s.GetAddress(fs.Arg(0), addrID)
}

View File

@@ -63,11 +63,6 @@ func main() {
Usage: "username of the account", Usage: "username of the account",
Required: true, Required: true,
}, },
&cli.StringFlag{
Name: "email",
Usage: "email of the account",
Required: true,
},
&cli.StringFlag{ &cli.StringFlag{
Name: "password", Name: "password",
Usage: "password of the account", Usage: "password of the account",
@@ -177,7 +172,6 @@ func createUserAction(c *cli.Context) error {
res, err := client.CreateUser(c.Context, &proto.CreateUserRequest{ res, err := client.CreateUser(c.Context, &proto.CreateUserRequest{
Username: c.String("username"), Username: c.String("username"),
Email: c.String("email"),
Password: []byte(c.String("password")), Password: []byte(c.String("password")),
}) })
if err != nil { if err != nil {

View File

@@ -71,7 +71,7 @@ func (s *service) GetInfo(ctx context.Context, req *proto.GetInfoRequest) (*prot
} }
func (s *service) CreateUser(ctx context.Context, req *proto.CreateUserRequest) (*proto.CreateUserResponse, error) { func (s *service) CreateUser(ctx context.Context, req *proto.CreateUserRequest) (*proto.CreateUserResponse, error) {
userID, addrID, err := s.server.CreateUser(req.Username, req.Email, req.Password) userID, addrID, err := s.server.CreateUser(req.Username, req.Password)
if err != nil { if err != nil {
return nil, err return nil, err
} }

11
server/helper_test.go Normal file
View File

@@ -0,0 +1,11 @@
package server
import (
"fmt"
"github.com/google/uuid"
)
func newMessageLiteral(from, to string) []byte {
return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!", from, to, uuid.New()))
}

22
server/init_test.go Normal file
View File

@@ -0,0 +1,22 @@
package server
import (
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
func init() {
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
if err != nil {
panic(err)
}
backend.GenerateKey = func(_, _ string, passphrase []byte, _ string, _ int) (string, error) {
encKey, err := key.Lock(passphrase)
if err != nil {
return "", err
}
return encKey.Armor()
}
}

View File

@@ -1,17 +1,16 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.28.1 // protoc-gen-go v1.28.0
// protoc v3.21.7 // protoc v3.21.10
// source: server.proto // source: server.proto
package proto package proto
import ( import (
reflect "reflect"
sync "sync"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
) )
const ( const (
@@ -166,7 +165,6 @@ type CreateUserRequest struct {
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
Email string `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"`
Password []byte `protobuf:"bytes,3,opt,name=password,proto3" json:"password,omitempty"` Password []byte `protobuf:"bytes,3,opt,name=password,proto3" json:"password,omitempty"`
} }
@@ -209,13 +207,6 @@ func (x *CreateUserRequest) GetUsername() string {
return "" return ""
} }
func (x *CreateUserRequest) GetEmail() string {
if x != nil {
return x.Email
}
return ""
}
func (x *CreateUserRequest) GetPassword() []byte { func (x *CreateUserRequest) GetPassword() []byte {
if x != nil { if x != nil {
return x.Password return x.Password
@@ -694,80 +685,79 @@ var file_server_proto_rawDesc = []byte{
0x73, 0x74, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x68, 0x6f, 0x73, 0x73, 0x74, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x68, 0x6f, 0x73,
0x74, 0x55, 0x52, 0x4c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c, 0x74, 0x55, 0x52, 0x4c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c,
0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c,
0x22, 0x61, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x22, 0x4b, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d,
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d,
0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20,
0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x44, 0x0a,
0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f,
0x6f, 0x72, 0x64, 0x22, 0x44, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20,
0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x16, 0x0a, 0x06, 0x61,
0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64,
0x44, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x72, 0x49, 0x44, 0x22, 0x2b, 0x0a, 0x11, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65,
0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x2b, 0x0a, 0x11, 0x52, 0x65, 0x76, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72,
0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44,
0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x60, 0x0a, 0x14, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65,
0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16,
0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06,
0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18,
0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x60, 0x0a, 0x14, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08,
0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08,
0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x2f, 0x0a, 0x15, 0x43, 0x72, 0x65, 0x61,
0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05,
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03,
0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x2f,
0x0a, 0x15, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22,
0x46, 0x0a, 0x14, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12,
0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x17, 0x0a, 0x15, 0x52, 0x65, 0x6d, 0x6f, 0x76,
0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
0x22, 0x82, 0x01, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12,
0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12,
0x24, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x52,
0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x2f, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c,
0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07,
0x6c, 0x61, 0x62, 0x65, 0x6c, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c,
0x61, 0x62, 0x65, 0x6c, 0x49, 0x44, 0x2a, 0x22, 0x0a, 0x09, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54,
0x79, 0x70, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x46, 0x4f, 0x4c, 0x44, 0x45, 0x52, 0x10, 0x00, 0x12,
0x09, 0x0a, 0x05, 0x4c, 0x41, 0x42, 0x45, 0x4c, 0x10, 0x01, 0x32, 0xa6, 0x03, 0x0a, 0x06, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x38, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f,
0x12, 0x15, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
0x41, 0x0a, 0x0a, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x12, 0x41, 0x0a, 0x0a, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72,
0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55,
0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73,
0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43,
0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61,
0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28,
0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x46, 0x0a, 0x14, 0x52, 0x65, 0x6d,
0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28,
0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64,
0x0b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x19, 0x2e, 0x70, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x44, 0x22, 0x17, 0x0a, 0x15, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65,
0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x82, 0x01, 0x0a, 0x12, 0x43,
0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x6e, 0x73, 0x65, 0x42, 0x2e, 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x70, 0x72, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28,
0x6f, 0x74, 0x6f, 0x6e, 0x74, 0x65, 0x63, 0x68, 0x2e, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x2f, 0x6c, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d,
0x69, 0x74, 0x65, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a,
0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52,
0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x24, 0x0a, 0x04, 0x74, 0x79, 0x70,
0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22,
0x2f, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x49,
0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x49, 0x44,
0x2a, 0x22, 0x0a, 0x09, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0a, 0x0a,
0x06, 0x46, 0x4f, 0x4c, 0x44, 0x45, 0x52, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x4c, 0x41, 0x42,
0x45, 0x4c, 0x10, 0x01, 0x32, 0xa6, 0x03, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x38, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x15, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66,
0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x0a, 0x43, 0x72, 0x65,
0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e,
0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65,
0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x0a,
0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76,
0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
0x4a, 0x0a, 0x0d, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73,
0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72,
0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x52,
0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65,
0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52,
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x43, 0x72, 0x65, 0x61, 0x74,
0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43,
0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x1a, 0x1a, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65,
0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a,
0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x50, 0x72, 0x6f, 0x74,
0x6f, 0x6e, 0x4d, 0x61, 0x69, 0x6c, 0x2f, 0x67, 0x6f, 0x2d, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x6e,
0x2d, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@@ -33,7 +33,6 @@ message GetInfoResponse {
message CreateUserRequest { message CreateUserRequest {
string username = 1; string username = 1;
string email = 2;
bytes password = 3; bytes password = 3;
} }

View File

@@ -1,14 +1,13 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.2.0 // - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.7 // - protoc v3.21.10
// source: server.proto // source: server.proto
package proto package proto
import ( import (
context "context" context "context"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes" codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status" status "google.golang.org/grpc/status"

51
server/quark.go Normal file
View File

@@ -0,0 +1,51 @@
package server
import (
"encoding/json"
"html/template"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// TODO: This is a disgusting hack to match the output of the internal quark command.
// They should return JSON instead of HTML!
func (s *Server) handleQuarkCommand() gin.HandlerFunc {
return func(c *gin.Context) {
res, err := s.b.RunQuarkCommand(c.Param("command"), strings.Split(c.Query("strInput"), " ")...)
if err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
var out string
switch res := res.(type) {
case string:
out = res
default:
b, err := json.MarshalIndent(res, "", " ")
if err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
out = string(b)
}
tmp, err := template.New("quarkCommand").Parse(`<html><body><div class="content">{{.Content}}</div></body></html>`)
if err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
if err := tmp.Execute(c.Writer, map[string]string{
"Content": template.HTMLEscapeString(out),
}); err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
}
}

73
server/quark_test.go Normal file
View File

@@ -0,0 +1,73 @@
package server
import (
"context"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/stretchr/testify/require"
)
func TestServer_Quark_CreateUser(t *testing.T) {
withServer(t, func(ctx context.Context, _ *Server, m *proton.Manager) {
// Create two users, one with keys and one without.
require.NoError(t, m.Quark(ctx, "user:create", "--name", "user-no-keys", "--password", "test", "--create-address"))
require.NoError(t, m.Quark(ctx, "user:create", "--name", "user-keys", "--password", "test", "--gen-keys", "rsa2048"))
{
// The address should be created but should have no keys.
c, _, err := m.NewClientWithLogin(ctx, "user-no-keys", []byte("test"))
require.NoError(t, err)
defer c.Close()
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
require.Len(t, addr, 1)
require.Len(t, addr[0].Keys, 0)
}
{
// The address should be created and should have keys.
c, _, err := m.NewClientWithLogin(ctx, "user-keys", []byte("test"))
require.NoError(t, err)
defer c.Close()
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
require.Len(t, addr, 1)
require.Len(t, addr[0].Keys, 1)
}
})
}
func TestServer_Quark_CreateAddress(t *testing.T) {
withServer(t, func(ctx context.Context, _ *Server, m *proton.Manager) {
// Create a user with one address.
require.NoError(t, m.Quark(ctx, "user:create", "--name", "user", "--password", "test", "--gen-keys", "rsa2048"))
// Login to the user.
c, _, err := m.NewClientWithLogin(ctx, "user", []byte("test"))
require.NoError(t, err)
defer c.Close()
// Get the user.
user, err := c.GetUser(ctx)
require.NoError(t, err)
// Initially the user should have one address and it should have keys.
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
require.Len(t, addr, 1)
require.Len(t, addr[0].Keys, 1)
// Create a new address.
require.NoError(t, m.Quark(ctx, "user:create:address", "--gen-keys", "rsa2048", user.ID, "test", "alias@proton.local"))
// Now the user should have two addresses, and they should both have keys.
newAddr, err := c.GetAddresses(ctx)
require.NoError(t, err)
require.Len(t, newAddr, 2)
require.Len(t, newAddr[0].Keys, 1)
require.Len(t, newAddr[1].Keys, 1)
})
}

View File

@@ -48,6 +48,9 @@ func initRouter(s *Server) {
if addresses := core.Group("/addresses"); addresses != nil { if addresses := core.Group("/addresses"); addresses != nil {
addresses.GET("", s.handleGetAddresses()) addresses.GET("", s.handleGetAddresses())
addresses.GET("/:addressID", s.handleGetAddress()) addresses.GET("/:addressID", s.handleGetAddress())
addresses.DELETE("/:addressID", s.handleDeleteAddress())
addresses.PUT("/:addressID/enable", s.handlePutAddressEnable())
addresses.PUT("/:addressID/disable", s.handlePutAddressDisable())
addresses.PUT("/order", s.handlePutAddressesOrder()) addresses.PUT("/order", s.handlePutAddressesOrder())
} }
@@ -114,6 +117,11 @@ func initRouter(s *Server) {
tests.GET("/ping", s.handleGetPing()) tests.GET("/ping", s.handleGetPing())
} }
// Quark routes don't need authentication.
if quark := s.r.Group("/internal/quark"); quark != nil {
quark.GET("/:command", s.handleQuarkCommand())
}
// Proxy any calls to the upstream server. // Proxy any calls to the upstream server.
if proxy := s.r.Group("/proxy"); proxy != nil { if proxy := s.r.Group("/proxy"); proxy != nil {
proxy.Any("/*path", s.handleProxy(proxy.BasePath())) proxy.Any("/*path", s.handleProxy(proxy.BasePath()))

View File

@@ -33,7 +33,10 @@ type Server struct {
callWatchers []callWatcher callWatchers []callWatcher
callWatchersLock sync.RWMutex callWatchersLock sync.RWMutex
// MinAppVersion is the minimum app version that the server will accept. // domain is the test server domain.
domain string
// minAppVersion is the minimum app version that the server will accept.
minAppVersion *semver.Version minAppVersion *semver.Version
// proxyOrigin is the URL of the origin server when the server is a proxy. // proxyOrigin is the URL of the origin server when the server is a proxy.
@@ -56,6 +59,7 @@ func New(opts ...Option) *Server {
return builder.build() return builder.build()
} }
// GetHostURL returns the API root to make calls to.
func (s *Server) GetHostURL() string { func (s *Server) GetHostURL() string {
return s.s.URL return s.s.URL
} }
@@ -65,6 +69,11 @@ func (s *Server) GetProxyURL() string {
return s.s.URL + "/proxy" return s.s.URL + "/proxy"
} }
// GetDomain returns the domain of the server.
func (s *Server) GetDomain() string {
return s.domain
}
func (s *Server) AddCallWatcher(fn func(Call), paths ...string) { func (s *Server) AddCallWatcher(fn func(Call), paths ...string) {
s.callWatchersLock.Lock() s.callWatchersLock.Lock()
defer s.callWatchersLock.Unlock() defer s.callWatchersLock.Unlock()
@@ -72,13 +81,15 @@ func (s *Server) AddCallWatcher(fn func(Call), paths ...string) {
s.callWatchers = append(s.callWatchers, newCallWatcher(fn, paths...)) s.callWatchers = append(s.callWatchers, newCallWatcher(fn, paths...))
} }
func (s *Server) CreateUser(username, email string, password []byte) (string, string, error) { // CreateUser creates a new server user with the given username and password.
// A single address will be created for the user, derived from the username and the server's domain.
func (s *Server) CreateUser(username string, password []byte) (string, string, error) {
userID, err := s.b.CreateUser(username, password) userID, err := s.b.CreateUser(username, password)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
addrID, err := s.b.CreateAddress(userID, email, password) addrID, err := s.b.CreateAddress(userID, username+"@"+s.domain, password, true)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@@ -114,7 +125,7 @@ func (s *Server) RemoveUserKey(userID, keyID string) error {
} }
func (s *Server) CreateAddress(userID, email string, password []byte) (string, error) { func (s *Server) CreateAddress(userID, email string, password []byte) (string, error) {
return s.b.CreateAddress(userID, email, password) return s.b.CreateAddress(userID, email, password, true)
} }
func (s *Server) RemoveAddress(userID, addrID string) error { func (s *Server) RemoveAddress(userID, addrID string) error {

View File

@@ -13,6 +13,7 @@ import (
type serverBuilder struct { type serverBuilder struct {
withTLS bool withTLS bool
domain string
logger io.Writer logger io.Writer
origin string origin string
cacher AuthCacher cacher AuthCacher
@@ -29,6 +30,7 @@ func newServerBuilder() *serverBuilder {
return &serverBuilder{ return &serverBuilder{
withTLS: true, withTLS: true,
domain: "proton.local",
logger: logger, logger: logger,
origin: proton.DefaultHostURL, origin: proton.DefaultHostURL,
} }
@@ -39,8 +41,9 @@ func (builder *serverBuilder) build() *Server {
s := &Server{ s := &Server{
r: gin.New(), r: gin.New(),
b: backend.New(time.Hour), b: backend.New(time.Hour, builder.domain),
domain: builder.domain,
proxyOrigin: builder.origin, proxyOrigin: builder.origin,
authCacher: builder.cacher, authCacher: builder.cacher,
} }
@@ -83,6 +86,21 @@ func (opt withTLS) config(builder *serverBuilder) {
builder.withTLS = opt.withTLS builder.withTLS = opt.withTLS
} }
// withDomain controls the domain of the server.
func WithDomain(domain string) Option {
return &withDomain{
domain: domain,
}
}
type withDomain struct {
domain string
}
func (opt withDomain) config(builder *serverBuilder) {
builder.domain = opt.domain
}
// WithLogger controls where Gin logs to. // WithLogger controls where Gin logs to.
func WithLogger(logger io.Writer) Option { func WithLogger(logger io.Writer) Option {
return &withLogger{ return &withLogger{

View File

@@ -26,13 +26,13 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
func TestServer(t *testing.T) { func TestServer_LoginLogout(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx) user, err := c.GetUser(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "user", user.Name) require.Equal(t, "user", user.Name)
require.Equal(t, "email@pm.me", user.Email) require.Equal(t, "user@"+s.GetDomain(), user.Email)
// Logout from the test API. // Logout from the test API.
require.NoError(t, c.AuthDelete(ctx)) require.NoError(t, c.AuthDelete(ctx))
@@ -45,7 +45,7 @@ func TestServer(t *testing.T) {
func TestServerMulti(t *testing.T) { func TestServerMulti(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
_, _, err := s.CreateUser("user", "email@pm.me", []byte("pass")) _, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
// Create one client. // Create one client.
@@ -120,7 +120,7 @@ func TestServer_Ping(t *testing.T) {
func TestServer_Bool(t *testing.T) { func TestServer_Bool(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1, func([]string) { withMessages(ctx, t, c, "pass", 1, func([]string) {
metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
require.NoError(t, err) require.NoError(t, err)
@@ -140,7 +140,7 @@ func TestServer_Bool(t *testing.T) {
func TestServer_Messages(t *testing.T) { func TestServer_Messages(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
// Get the messages. // Get the messages.
metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
@@ -180,7 +180,7 @@ func TestServer_Messages(t *testing.T) {
func TestServer_MessageFilter(t *testing.T) { func TestServer_MessageFilter(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
// Get the messages. // Get the messages.
metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
@@ -210,7 +210,7 @@ func TestServer_MessageFilter(t *testing.T) {
func TestServer_MessageIDs(t *testing.T) { func TestServer_MessageIDs(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 10000, func(wantMessageIDs []string) { withMessages(ctx, t, c, "pass", 10000, func(wantMessageIDs []string) {
allMessageIDs, err := c.GetMessageIDs(ctx, "") allMessageIDs, err := c.GetMessageIDs(ctx, "")
require.NoError(t, err) require.NoError(t, err)
@@ -226,7 +226,7 @@ func TestServer_MessageIDs(t *testing.T) {
func TestServer_MessagesDelete(t *testing.T) { func TestServer_MessagesDelete(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
// Get the messages. // Get the messages.
metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
@@ -255,7 +255,7 @@ func TestServer_MessagesDelete(t *testing.T) {
func TestServer_MessagesDeleteAfterUpdate(t *testing.T) { func TestServer_MessagesDeleteAfterUpdate(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
// Get the initial event ID. // Get the initial event ID.
eventID, err := c.GetLatestEventID(ctx) eventID, err := c.GetLatestEventID(ctx)
@@ -285,7 +285,7 @@ func TestServer_MessagesDeleteAfterUpdate(t *testing.T) {
func TestServer_Events(t *testing.T) { func TestServer_Events(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 3, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 3, func(messageIDs []string) {
// Get the latest event ID to stream from. // Get the latest event ID to stream from.
fromEventID, err := c.GetLatestEventID(ctx) fromEventID, err := c.GetLatestEventID(ctx)
@@ -367,7 +367,7 @@ func TestServer_Events(t *testing.T) {
func TestServer_Events_Multi(t *testing.T) { func TestServer_Events_Multi(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
withUser(ctx, t, s, m, fmt.Sprintf("user%v", i), fmt.Sprintf("email%v@pm.me", i), "pass", func(c *proton.Client) { withUser(ctx, t, s, m, fmt.Sprintf("user%v", i), "pass", func(c *proton.Client) {
latest, err := c.GetLatestEventID(ctx) latest, err := c.GetLatestEventID(ctx)
require.NoError(t, err) require.NoError(t, err)
@@ -393,7 +393,7 @@ func TestServer_Events_Multi(t *testing.T) {
func TestServer_Events_Refresh(t *testing.T) { func TestServer_Events_Refresh(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx) user, err := c.GetUser(ctx)
require.NoError(t, err) require.NoError(t, err)
@@ -417,11 +417,11 @@ func TestServer_Events_Refresh(t *testing.T) {
func TestServer_RevokeUser(t *testing.T) { func TestServer_RevokeUser(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx) user, err := c.GetUser(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "user", user.Name) require.Equal(t, "user", user.Name)
require.Equal(t, "email@pm.me", user.Email) require.Equal(t, "user@"+s.GetDomain(), user.Email)
// Revoke the user's auth. // Revoke the user's auth.
require.NoError(t, s.RevokeUser(user.ID)) require.NoError(t, s.RevokeUser(user.ID))
@@ -434,7 +434,7 @@ func TestServer_RevokeUser(t *testing.T) {
func TestServer_Calls(t *testing.T) { func TestServer_Calls(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
var calls []Call var calls []Call
// Watch calls that are made. // Watch calls that are made.
@@ -466,7 +466,7 @@ func TestServer_Calls(t *testing.T) {
func TestServer_Calls_Status(t *testing.T) { func TestServer_Calls_Status(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
var calls []Call var calls []Call
// Watch calls that are made. // Watch calls that are made.
@@ -492,7 +492,7 @@ func TestServer_Calls_Request(t *testing.T) {
calls = append(calls, call) calls = append(calls, call)
}) })
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(*proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) {
require.Equal( require.Equal(
t, t,
calls[0].RequestBody, calls[0].RequestBody,
@@ -510,7 +510,7 @@ func TestServer_Calls_Response(t *testing.T) {
calls = append(calls, call) calls = append(calls, call)
}) })
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
salts, err := c.GetSalts(ctx) salts, err := c.GetSalts(ctx)
require.NoError(t, err) require.NoError(t, err)
@@ -531,7 +531,7 @@ func TestServer_Calls_Cookies(t *testing.T) {
calls = append(calls, call) calls = append(calls, call)
}) })
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(*proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) {
// The header in the first call's response should set the Session-Id cookie. // The header in the first call's response should set the Session-Id cookie.
resHeader := (&http.Response{Header: calls[len(calls)-2].ResponseHeader}) resHeader := (&http.Response{Header: calls[len(calls)-2].ResponseHeader})
require.Len(t, resHeader.Cookies(), 1) require.Len(t, resHeader.Cookies(), 1)
@@ -572,7 +572,7 @@ func TestServer_Calls_Manager(t *testing.T) {
func TestServer_CreateMessage(t *testing.T) { func TestServer_CreateMessage(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx) user, err := c.GetUser(ctx)
require.NoError(t, err) require.NoError(t, err)
@@ -592,14 +592,15 @@ func TestServer_CreateMessage(t *testing.T) {
Message: proton.DraftTemplate{ Message: proton.DraftTemplate{
Subject: "My subject", Subject: "My subject",
Sender: &mail.Address{Address: addr[0].Email}, Sender: &mail.Address{Address: addr[0].Email},
ToList: []*mail.Address{{Address: "recipient@pm.me"}}, ToList: []*mail.Address{{Address: "recipient@example.com"}},
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, addr[0].ID, draft.AddressID) require.Equal(t, addr[0].ID, draft.AddressID)
require.Equal(t, "My subject", draft.Subject) require.Equal(t, "My subject", draft.Subject)
require.Equal(t, &mail.Address{Address: "email@pm.me"}, draft.Sender) require.Equal(t, &mail.Address{Address: "user@" + s.GetDomain()}, draft.Sender)
require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, draft.ToList)
require.ElementsMatch(t, []string{proton.AllMailLabel, proton.AllDraftsLabel, proton.DraftsLabel}, draft.LabelIDs) require.ElementsMatch(t, []string{proton.AllMailLabel, proton.AllDraftsLabel, proton.DraftsLabel}, draft.LabelIDs)
}) })
}) })
@@ -607,7 +608,7 @@ func TestServer_CreateMessage(t *testing.T) {
func TestServer_UpdateDraft(t *testing.T) { func TestServer_UpdateDraft(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx) user, err := c.GetUser(ctx)
require.NoError(t, err) require.NoError(t, err)
@@ -628,13 +629,14 @@ func TestServer_UpdateDraft(t *testing.T) {
Message: proton.DraftTemplate{ Message: proton.DraftTemplate{
Subject: "My subject", Subject: "My subject",
Sender: &mail.Address{Address: addr[0].Email}, Sender: &mail.Address{Address: addr[0].Email},
ToList: []*mail.Address{{Address: "recipient@pm.me"}}, ToList: []*mail.Address{{Address: "recipient@example.com"}},
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, addr[0].ID, draft.AddressID) require.Equal(t, addr[0].ID, draft.AddressID)
require.Equal(t, "My subject", draft.Subject) require.Equal(t, "My subject", draft.Subject)
require.Equal(t, &mail.Address{Address: "email@pm.me"}, draft.Sender) require.Equal(t, &mail.Address{Address: "user@" + s.GetDomain()}, draft.Sender)
require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, draft.ToList)
// Create an event stream to watch for an update event. // Create an event stream to watch for an update event.
fromEventID, err := c.GetLatestEventID(ctx) fromEventID, err := c.GetLatestEventID(ctx)
@@ -646,7 +648,7 @@ func TestServer_UpdateDraft(t *testing.T) {
msg, err := c.UpdateDraft(ctx, draft.ID, addrKRs[addr[0].ID], proton.UpdateDraftReq{ msg, err := c.UpdateDraft(ctx, draft.ID, addrKRs[addr[0].ID], proton.UpdateDraftReq{
Message: proton.DraftTemplate{ Message: proton.DraftTemplate{
Subject: "Edited subject", Subject: "Edited subject",
ToList: []*mail.Address{{Address: "edited@pm.me"}}, ToList: []*mail.Address{{Address: "edited@example.com"}},
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -670,7 +672,7 @@ func TestServer_UpdateDraft(t *testing.T) {
require.Equal(t, draft.ID, event.Messages[0].ID) require.Equal(t, draft.ID, event.Messages[0].ID)
require.Equal(t, "Edited subject", event.Messages[0].Message.Subject) require.Equal(t, "Edited subject", event.Messages[0].Message.Subject)
require.Equal(t, []*mail.Address{{Address: "edited@pm.me"}}, event.Messages[0].Message.ToList) require.Equal(t, []*mail.Address{{Address: "edited@example.com"}}, event.Messages[0].Message.ToList)
return true return true
}, 5*time.Second, time.Millisecond*100) }, 5*time.Second, time.Millisecond*100)
@@ -680,7 +682,7 @@ func TestServer_UpdateDraft(t *testing.T) {
func TestServer_SendMessage(t *testing.T) { func TestServer_SendMessage(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(ctx) user, err := c.GetUser(ctx)
require.NoError(t, err) require.NoError(t, err)
@@ -700,7 +702,7 @@ func TestServer_SendMessage(t *testing.T) {
Message: proton.DraftTemplate{ Message: proton.DraftTemplate{
Subject: "My subject", Subject: "My subject",
Sender: &mail.Address{Address: addr[0].Email}, Sender: &mail.Address{Address: addr[0].Email},
ToList: []*mail.Address{{Address: "recipient@pm.me"}}, ToList: []*mail.Address{{Address: "recipient@example.com"}},
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
@@ -711,6 +713,7 @@ func TestServer_SendMessage(t *testing.T) {
require.Equal(t, draft.ID, sent.ID) require.Equal(t, draft.ID, sent.ID)
require.Equal(t, addr[0].ID, sent.AddressID) require.Equal(t, addr[0].ID, sent.AddressID)
require.Equal(t, "My subject", sent.Subject) require.Equal(t, "My subject", sent.Subject)
require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, sent.ToList)
require.Contains(t, sent.LabelIDs, proton.SentLabel) require.Contains(t, sent.LabelIDs, proton.SentLabel)
}) })
}) })
@@ -718,7 +721,7 @@ func TestServer_SendMessage(t *testing.T) {
func TestServer_AuthDelete(t *testing.T) { func TestServer_AuthDelete(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
require.NoError(t, c.AuthDelete(ctx)) require.NoError(t, c.AuthDelete(ctx))
}) })
}) })
@@ -733,7 +736,7 @@ func TestServer_ForceUpgrade(t *testing.T) {
s.SetMinAppVersion(semver.MustParse("1.0.0")) s.SetMinAppVersion(semver.MustParse("1.0.0"))
if _, _, err := s.CreateUser("user", "email@pm.me", []byte("pass")); err != nil { if _, _, err := s.CreateUser("user", []byte("pass")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -759,7 +762,7 @@ func TestServer_ForceUpgrade(t *testing.T) {
func TestServer_Import(t *testing.T) { func TestServer_Import(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@@ -913,7 +916,7 @@ func TestServer_Labels(t *testing.T) {
} }
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@@ -1055,7 +1058,7 @@ func TestServer_Import_FlagsAndLabels(t *testing.T) {
} }
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@@ -1082,7 +1085,7 @@ func TestServer_Import_FlagsAndLabels(t *testing.T) {
Flags: tt.flags, Flags: tt.flags,
LabelIDs: tt.labelIDs, LabelIDs: tt.labelIDs,
}, },
Message: []byte(fmt.Sprintf("From: sender@pm.me\r\nReceiver: recipient@pm.me\r\nSubject: %v\r\n\r\nHello World!", uuid.New())), Message: newMessageLiteral("sender@example.com", "recipient@example.com"),
}}...)) }}...))
if tt.wantError { if tt.wantError {
require.Error(t, err) require.Error(t, err)
@@ -1107,12 +1110,12 @@ func TestServer_Import_FlagsAndLabels(t *testing.T) {
func TestServer_PublicKeys(t *testing.T) { func TestServer_PublicKeys(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
if _, _, err := s.CreateUser("other", "other@pm.me", []byte("pass")); err != nil { if _, _, err := s.CreateUser("other", []byte("pass")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
intKeys, intType, err := c.GetPublicKeys(ctx, "other@pm.me") intKeys, intType, err := c.GetPublicKeys(ctx, "other@"+s.GetDomain())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, proton.RecipientTypeInternal, intType) require.Equal(t, proton.RecipientTypeInternal, intType)
require.Len(t, intKeys, 1) require.Len(t, intKeys, 1)
@@ -1133,8 +1136,11 @@ func TestServer_Proxy(t *testing.T) {
calls = append(calls, call) calls = append(calls, call)
}) })
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(_ *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) {
proxy := New(WithProxyOrigin(s.GetHostURL())) proxy := New(
WithProxyOrigin(s.GetHostURL()),
WithProxyTransport(proton.InsecureTransport()),
)
defer proxy.Close() defer proxy.Close()
m := proton.New( m := proton.New(
@@ -1161,9 +1167,10 @@ func TestServer_Proxy(t *testing.T) {
func TestServer_Proxy_Cache(t *testing.T) { func TestServer_Proxy_Cache(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(_ *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) {
proxy := New( proxy := New(
WithProxyOrigin(s.GetHostURL()), WithProxyOrigin(s.GetHostURL()),
WithProxyTransport(proton.InsecureTransport()),
WithAuthCacher(NewAuthCache()), WithAuthCacher(NewAuthCache()),
) )
defer proxy.Close() defer proxy.Close()
@@ -1190,9 +1197,10 @@ func TestServer_Proxy_Cache(t *testing.T) {
func TestServer_Proxy_AuthDelete(t *testing.T) { func TestServer_Proxy_AuthDelete(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(_ *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) {
proxy := New( proxy := New(
WithProxyOrigin(s.GetHostURL()), WithProxyOrigin(s.GetHostURL()),
WithProxyTransport(proton.InsecureTransport()),
WithAuthCacher(NewAuthCache()), WithAuthCacher(NewAuthCache()),
) )
defer proxy.Close() defer proxy.Close()
@@ -1299,7 +1307,7 @@ func TestServer_RealProxy_Cache(t *testing.T) {
func TestServer_Messages_Fetch(t *testing.T) { func TestServer_Messages_Fetch(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
ctl := proton.NewNetCtl() ctl := proton.NewNetCtl()
@@ -1341,7 +1349,7 @@ func TestServer_Messages_Fetch(t *testing.T) {
func TestServer_Messages_Status(t *testing.T) { func TestServer_Messages_Status(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) {
ctl := proton.NewNetCtl() ctl := proton.NewNetCtl()
@@ -1381,7 +1389,7 @@ func TestServer_Messages_Status(t *testing.T) {
func TestServer_Labels_Duplicates(t *testing.T) { func TestServer_Labels_Duplicates(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
req := proton.CreateLabelReq{ req := proton.CreateLabelReq{
Name: uuid.NewString(), Name: uuid.NewString(),
Color: "#f66", Color: "#f66",
@@ -1400,7 +1408,7 @@ func TestServer_Labels_Duplicates(t *testing.T) {
func TestServer_Labels_Duplicates_Update(t *testing.T) { func TestServer_Labels_Duplicates_Update(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
label1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ label1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{
Name: uuid.NewString(), Name: uuid.NewString(),
Color: "#f66", Color: "#f66",
@@ -1441,7 +1449,7 @@ func TestServer_Labels_Duplicates_Update(t *testing.T) {
func TestServer_Labels_Subfolders(t *testing.T) { func TestServer_Labels_Subfolders(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{
Name: uuid.NewString(), Name: uuid.NewString(),
Color: "#f66", Color: "#f66",
@@ -1472,7 +1480,7 @@ func TestServer_Labels_Subfolders(t *testing.T) {
func TestServer_Labels_Subfolders_Reassign(t *testing.T) { func TestServer_Labels_Subfolders_Reassign(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
parent1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ parent1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{
Name: uuid.NewString(), Name: uuid.NewString(),
Color: "#f66", Color: "#f66",
@@ -1520,7 +1528,7 @@ func TestServer_Labels_Subfolders_Reassign(t *testing.T) {
func TestServer_Labels_Subfolders_DeleteParentWithChildren(t *testing.T) { func TestServer_Labels_Subfolders_DeleteParentWithChildren(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{
Name: uuid.NewString(), Name: uuid.NewString(),
Color: "#f66", Color: "#f66",
@@ -1564,9 +1572,54 @@ func TestServer_Labels_Subfolders_DeleteParentWithChildren(t *testing.T) {
}) })
} }
func TestServer_AddressCreateDelete(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(context.Background())
require.NoError(t, err)
// Create an address.
alias, err := s.CreateAddress(user.ID, "alias@example.com", []byte("pass"))
require.NoError(t, err)
// The user should have two addresses, both enabled.
{
addr, err := c.GetAddresses(context.Background())
require.NoError(t, err)
require.Len(t, addr, 2)
require.Equal(t, addr[0].Status, proton.AddressStatusEnabled)
require.Equal(t, addr[1].Status, proton.AddressStatusEnabled)
}
// Disable the alias.
require.NoError(t, c.DisableAddress(context.Background(), alias))
// The user should have two addresses, the primary enabled and the alias disabled.
{
addr, err := c.GetAddresses(context.Background())
require.NoError(t, err)
require.Len(t, addr, 2)
require.Equal(t, addr[0].Status, proton.AddressStatusEnabled)
require.Equal(t, addr[1].Status, proton.AddressStatusDisabled)
}
// Delete the alias.
require.NoError(t, c.DeleteAddress(context.Background(), alias))
// The user should have one address, the primary enabled.
{
addr, err := c.GetAddresses(context.Background())
require.NoError(t, err)
require.Len(t, addr, 1)
require.Equal(t, addr[0].Status, proton.AddressStatusEnabled)
}
})
})
}
func TestServer_AddressOrder(t *testing.T) { func TestServer_AddressOrder(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
user, err := c.GetUser(context.Background()) user, err := c.GetUser(context.Background())
require.NoError(t, err) require.NoError(t, err)
@@ -1574,13 +1627,13 @@ func TestServer_AddressOrder(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Create 3 additional addresses. // Create 3 additional addresses.
addr1, err := s.CreateAddress(user.ID, "addr1@pm.me", []byte("pass")) addr1, err := s.CreateAddress(user.ID, "addr1@example.com", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
addr2, err := s.CreateAddress(user.ID, "addr2@pm.me", []byte("pass")) addr2, err := s.CreateAddress(user.ID, "addr2@example.com", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
addr3, err := s.CreateAddress(user.ID, "addr3@pm.me", []byte("pass")) addr3, err := s.CreateAddress(user.ID, "addr3@example.com", []byte("pass"))
require.NoError(t, err) require.NoError(t, err)
addresses, err := c.GetAddresses(context.Background()) addresses, err := c.GetAddresses(context.Background())
@@ -1626,11 +1679,11 @@ func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.
fn(ctx, s, m) fn(ctx, s, m)
} }
func withUser(ctx context.Context, t *testing.T, s *Server, m *proton.Manager, username, email, password string, fn func(c *proton.Client)) { func withUser(ctx context.Context, t *testing.T, s *Server, m *proton.Manager, user, pass string, fn func(c *proton.Client)) {
_, _, err := s.CreateUser(username, email, []byte(password)) _, _, err := s.CreateUser(user, []byte(pass))
require.NoError(t, err) require.NoError(t, err)
c, _, err := m.NewClientWithLogin(ctx, username, []byte(password)) c, _, err := m.NewClientWithLogin(ctx, user, []byte(pass))
require.NoError(t, err) require.NoError(t, err)
defer c.Close() defer c.Close()
@@ -1676,7 +1729,7 @@ func importMessages(
Flags: flags, Flags: flags,
Unread: true, Unread: true,
}, },
Message: []byte(fmt.Sprintf("From: sender@pm.me\r\nReceiver: recipient@pm.me\r\nSubject: %v\r\n\r\nHello World!", uuid.New())), Message: newMessageLiteral("sender@example.com", "recipient@example.com"),
} }
})) }))

View File

@@ -1,6 +0,0 @@
To: recipient@pm.me
From: sender@pm.me
Subject: Test
Content-Type: text/plain; charset=utf-8
Test