fix(IMEX-66): prevent potential double encryption on import

This commit is contained in:
ElectroNafta
2025-04-22 13:13:01 +02:00
committed by Atanas Janeshliev
parent e624a080f7
commit c9726b8d6e
4 changed files with 95 additions and 10 deletions

View File

@@ -19,8 +19,8 @@ const (
// maxImportCount is the maximum number of messages that can be imported in a single request.
maxImportCount = 10
// maxImportSize is the maximum total request size permitted for a single import request.
maxImportSize = 30 * 1024 * 1024
// MaxImportSize is the maximum total request size permitted for a single import request.
MaxImportSize = 30 * 1024 * 1024
)
var ErrImportEncrypt = errors.New("failed to encrypt message")
@@ -31,23 +31,28 @@ type ImportResStream stream.Stream[ImportRes] // gomock does not support generic
func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, workers, buffer int, req ...ImportReq) (ImportResStream, error) {
// Encrypt each message.
for idx := range req {
enc, err := EncryptRFC822(addrKR, req[idx].Message)
// Encryption might mangle the original message bufer, so use a copy.
msgCopy := make([]byte, len(req[idx].Message))
copy(msgCopy, req[idx].Message)
enc, err := EncryptRFC822(addrKR, msgCopy)
if err != nil {
return nil, fmt.Errorf("%w %v: %v", ErrImportEncrypt, idx, err)
}
req[idx].Message = enc
req[idx].encryptedMessage = enc
}
// If any of the messages exceed the maximum import size, return an error.
if xslices.Any(req, func(req ImportReq) bool { return len(req.Message) > maxImportSize }) {
if xslices.Any(req, func(req ImportReq) bool { return len(req.encryptedMessage) > MaxImportSize }) {
return nil, ErrImportSizeExceeded
}
return stream.Flatten(parallel.MapStream(
ctx,
stream.FromIterator(iterator.Slice(ChunkSized(req, maxImportCount, maxImportSize, func(req ImportReq) int {
return len(req.Message)
stream.FromIterator(iterator.Slice(ChunkSized(req, maxImportCount, MaxImportSize, func(req ImportReq) int {
return len(req.encryptedMessage)
}))),
workers,
buffer,

View File

@@ -8,8 +8,13 @@ import (
)
type ImportReq struct {
Metadata ImportMetadata
Message []byte
Metadata ImportMetadata
Message []byte
encryptedMessage []byte
}
func (r ImportReq) GetEncryptedMessageLength() int {
return len(r.encryptedMessage)
}
type namedImportReq struct {
@@ -42,7 +47,7 @@ func buildImportReqFields(req []namedImportReq) ([]*resty.MultipartField, error)
Param: req.Name,
FileName: req.Name + ".eml",
ContentType: string(rfc822.MessageRFC822),
Stream: resty.NewByteMultipartStream(append(req.Message, "\r\n"...)),
Stream: resty.NewByteMultipartStream(append(req.encryptedMessage, "\r\n"...)),
})
}

View File

@@ -2,6 +2,7 @@ package server
import (
"fmt"
"strings"
"github.com/google/uuid"
)
@@ -13,3 +14,8 @@ func newMessageLiteral(from, to string) []byte {
func newMessageLiteralWithSubject(from, to, subject string) []byte {
return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!", from, to, subject))
}
func newMessageLiteralWithSubjectAndSize(from, to, subject string, paddingSize int) []byte {
padding := strings.Repeat("A", paddingSize)
return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!Padding:%s", from, to, subject, padding))
}

View File

@@ -2700,3 +2700,72 @@ func getFullMessages(ctx context.Context,
},
)
}
func TestServer_Import_MessageSizeExceeded_NoDoubleEncryption(t *testing.T) {
withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) {
withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
user, err := c.GetUser(ctx)
require.NoError(t, err)
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
salt, err := c.GetSalts(ctx)
require.NoError(t, err)
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{})
require.NoError(t, err)
reqs := []proton.ImportReq{
{
Metadata: proton.ImportMetadata{
AddressID: addr[0].ID,
LabelIDs: []string{},
Flags: proton.MessageFlagReceived,
Unread: true,
},
Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), 40),
},
{
Metadata: proton.ImportMetadata{
AddressID: addr[0].ID,
LabelIDs: []string{},
Flags: proton.MessageFlagReceived,
Unread: true,
},
Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), proton.MaxImportSize),
},
{
Metadata: proton.ImportMetadata{
AddressID: addr[0].ID,
LabelIDs: []string{},
Flags: proton.MessageFlagReceived,
Unread: true,
},
Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), 20),
},
}
reqsCopy := make([]proton.ImportReq, len(reqs))
for idx := range len(reqs) {
messageCopy := make([]byte, len(reqs[idx].Message))
copy(messageCopy, reqs[idx].Message)
reqsCopy[idx].Message = messageCopy
}
_, err = c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), reqs...)
require.ErrorIs(t, err, proton.ErrImportSizeExceeded)
for idx := range len(reqs) {
require.Equal(t, reqsCopy[idx].Message, reqs[idx].Message)
require.True(t, reqs[idx].GetEncryptedMessageLength() > 0)
}
})
})
}