mirror of
https://github.com/ProtonMail/go-proton-api.git
synced 2025-12-23 23:57:50 -05:00
fix(IMEX-66): prevent potential double encryption on import
This commit is contained in:
committed by
Atanas Janeshliev
parent
e624a080f7
commit
c9726b8d6e
@@ -19,8 +19,8 @@ const (
|
||||
// maxImportCount is the maximum number of messages that can be imported in a single request.
|
||||
maxImportCount = 10
|
||||
|
||||
// maxImportSize is the maximum total request size permitted for a single import request.
|
||||
maxImportSize = 30 * 1024 * 1024
|
||||
// MaxImportSize is the maximum total request size permitted for a single import request.
|
||||
MaxImportSize = 30 * 1024 * 1024
|
||||
)
|
||||
|
||||
var ErrImportEncrypt = errors.New("failed to encrypt message")
|
||||
@@ -31,23 +31,28 @@ type ImportResStream stream.Stream[ImportRes] // gomock does not support generic
|
||||
func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, workers, buffer int, req ...ImportReq) (ImportResStream, error) {
|
||||
// Encrypt each message.
|
||||
for idx := range req {
|
||||
enc, err := EncryptRFC822(addrKR, req[idx].Message)
|
||||
|
||||
// Encryption might mangle the original message bufer, so use a copy.
|
||||
msgCopy := make([]byte, len(req[idx].Message))
|
||||
copy(msgCopy, req[idx].Message)
|
||||
|
||||
enc, err := EncryptRFC822(addrKR, msgCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w %v: %v", ErrImportEncrypt, idx, err)
|
||||
}
|
||||
|
||||
req[idx].Message = enc
|
||||
req[idx].encryptedMessage = enc
|
||||
}
|
||||
|
||||
// If any of the messages exceed the maximum import size, return an error.
|
||||
if xslices.Any(req, func(req ImportReq) bool { return len(req.Message) > maxImportSize }) {
|
||||
if xslices.Any(req, func(req ImportReq) bool { return len(req.encryptedMessage) > MaxImportSize }) {
|
||||
return nil, ErrImportSizeExceeded
|
||||
}
|
||||
|
||||
return stream.Flatten(parallel.MapStream(
|
||||
ctx,
|
||||
stream.FromIterator(iterator.Slice(ChunkSized(req, maxImportCount, maxImportSize, func(req ImportReq) int {
|
||||
return len(req.Message)
|
||||
stream.FromIterator(iterator.Slice(ChunkSized(req, maxImportCount, MaxImportSize, func(req ImportReq) int {
|
||||
return len(req.encryptedMessage)
|
||||
}))),
|
||||
workers,
|
||||
buffer,
|
||||
|
||||
@@ -8,8 +8,13 @@ import (
|
||||
)
|
||||
|
||||
type ImportReq struct {
|
||||
Metadata ImportMetadata
|
||||
Message []byte
|
||||
Metadata ImportMetadata
|
||||
Message []byte
|
||||
encryptedMessage []byte
|
||||
}
|
||||
|
||||
func (r ImportReq) GetEncryptedMessageLength() int {
|
||||
return len(r.encryptedMessage)
|
||||
}
|
||||
|
||||
type namedImportReq struct {
|
||||
@@ -42,7 +47,7 @@ func buildImportReqFields(req []namedImportReq) ([]*resty.MultipartField, error)
|
||||
Param: req.Name,
|
||||
FileName: req.Name + ".eml",
|
||||
ContentType: string(rfc822.MessageRFC822),
|
||||
Stream: resty.NewByteMultipartStream(append(req.Message, "\r\n"...)),
|
||||
Stream: resty.NewByteMultipartStream(append(req.encryptedMessage, "\r\n"...)),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -13,3 +14,8 @@ func newMessageLiteral(from, to string) []byte {
|
||||
func newMessageLiteralWithSubject(from, to, subject string) []byte {
|
||||
return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!", from, to, subject))
|
||||
}
|
||||
|
||||
func newMessageLiteralWithSubjectAndSize(from, to, subject string, paddingSize int) []byte {
|
||||
padding := strings.Repeat("A", paddingSize)
|
||||
return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!Padding:%s", from, to, subject, padding))
|
||||
}
|
||||
|
||||
@@ -2700,3 +2700,72 @@ func getFullMessages(ctx context.Context,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestServer_Import_MessageSizeExceeded_NoDoubleEncryption(t *testing.T) {
|
||||
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
|
||||
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
user, err := c.GetUser(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr, err := c.GetAddresses(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
salt, err := c.GetSalts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
|
||||
require.NoError(t, err)
|
||||
|
||||
reqs := []proton.ImportReq{
|
||||
{
|
||||
Metadata: proton.ImportMetadata{
|
||||
AddressID: addr[0].ID,
|
||||
LabelIDs: []string{},
|
||||
Flags: proton.MessageFlagReceived,
|
||||
Unread: true,
|
||||
},
|
||||
Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), 40),
|
||||
},
|
||||
{
|
||||
Metadata: proton.ImportMetadata{
|
||||
AddressID: addr[0].ID,
|
||||
LabelIDs: []string{},
|
||||
Flags: proton.MessageFlagReceived,
|
||||
Unread: true,
|
||||
},
|
||||
Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), proton.MaxImportSize),
|
||||
},
|
||||
{
|
||||
Metadata: proton.ImportMetadata{
|
||||
AddressID: addr[0].ID,
|
||||
LabelIDs: []string{},
|
||||
Flags: proton.MessageFlagReceived,
|
||||
Unread: true,
|
||||
},
|
||||
Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), 20),
|
||||
},
|
||||
}
|
||||
|
||||
reqsCopy := make([]proton.ImportReq, len(reqs))
|
||||
for idx := range len(reqs) {
|
||||
messageCopy := make([]byte, len(reqs[idx].Message))
|
||||
copy(messageCopy, reqs[idx].Message)
|
||||
reqsCopy[idx].Message = messageCopy
|
||||
}
|
||||
|
||||
_, err = c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), reqs...)
|
||||
require.ErrorIs(t, err, proton.ErrImportSizeExceeded)
|
||||
|
||||
for idx := range len(reqs) {
|
||||
require.Equal(t, reqsCopy[idx].Message, reqs[idx].Message)
|
||||
require.True(t, reqs[idx].GetEncryptedMessageLength() > 0)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user