diff --git a/message_import.go b/message_import.go index 592ece2..884234c 100644 --- a/message_import.go +++ b/message_import.go @@ -19,8 +19,8 @@ const ( // maxImportCount is the maximum number of messages that can be imported in a single request. maxImportCount = 10 - // maxImportSize is the maximum total request size permitted for a single import request. - maxImportSize = 30 * 1024 * 1024 + // MaxImportSize is the maximum total request size permitted for a single import request. + MaxImportSize = 30 * 1024 * 1024 ) var ErrImportEncrypt = errors.New("failed to encrypt message") @@ -31,23 +31,28 @@ type ImportResStream stream.Stream[ImportRes] // gomock does not support generic func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, workers, buffer int, req ...ImportReq) (ImportResStream, error) { // Encrypt each message. for idx := range req { - enc, err := EncryptRFC822(addrKR, req[idx].Message) + + // Encryption might mangle the original message bufer, so use a copy. + msgCopy := make([]byte, len(req[idx].Message)) + copy(msgCopy, req[idx].Message) + + enc, err := EncryptRFC822(addrKR, msgCopy) if err != nil { return nil, fmt.Errorf("%w %v: %v", ErrImportEncrypt, idx, err) } - req[idx].Message = enc + req[idx].encryptedMessage = enc } // If any of the messages exceed the maximum import size, return an error. - if xslices.Any(req, func(req ImportReq) bool { return len(req.Message) > maxImportSize }) { + if xslices.Any(req, func(req ImportReq) bool { return len(req.encryptedMessage) > MaxImportSize }) { return nil, ErrImportSizeExceeded } return stream.Flatten(parallel.MapStream( ctx, - stream.FromIterator(iterator.Slice(ChunkSized(req, maxImportCount, maxImportSize, func(req ImportReq) int { - return len(req.Message) + stream.FromIterator(iterator.Slice(ChunkSized(req, maxImportCount, MaxImportSize, func(req ImportReq) int { + return len(req.encryptedMessage) }))), workers, buffer, diff --git a/message_import_types.go b/message_import_types.go index 667898b..7f62f67 100644 --- a/message_import_types.go +++ b/message_import_types.go @@ -8,8 +8,13 @@ import ( ) type ImportReq struct { - Metadata ImportMetadata - Message []byte + Metadata ImportMetadata + Message []byte + encryptedMessage []byte +} + +func (r ImportReq) GetEncryptedMessageLength() int { + return len(r.encryptedMessage) } type namedImportReq struct { @@ -42,7 +47,7 @@ func buildImportReqFields(req []namedImportReq) ([]*resty.MultipartField, error) Param: req.Name, FileName: req.Name + ".eml", ContentType: string(rfc822.MessageRFC822), - Stream: resty.NewByteMultipartStream(append(req.Message, "\r\n"...)), + Stream: resty.NewByteMultipartStream(append(req.encryptedMessage, "\r\n"...)), }) } diff --git a/server/helper_test.go b/server/helper_test.go index b17d610..d7575f0 100644 --- a/server/helper_test.go +++ b/server/helper_test.go @@ -2,6 +2,7 @@ package server import ( "fmt" + "strings" "github.com/google/uuid" ) @@ -13,3 +14,8 @@ func newMessageLiteral(from, to string) []byte { func newMessageLiteralWithSubject(from, to, subject string) []byte { return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!", from, to, subject)) } + +func newMessageLiteralWithSubjectAndSize(from, to, subject string, paddingSize int) []byte { + padding := strings.Repeat("A", paddingSize) + return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!Padding:%s", from, to, subject, padding)) +} diff --git a/server/server_test.go b/server/server_test.go index 3195894..628e217 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2700,3 +2700,72 @@ func getFullMessages(ctx context.Context, }, ) } + +func TestServer_Import_MessageSizeExceeded_NoDoubleEncryption(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) + require.NoError(t, err) + + reqs := []proton.ImportReq{ + { + Metadata: proton.ImportMetadata{ + AddressID: addr[0].ID, + LabelIDs: []string{}, + Flags: proton.MessageFlagReceived, + Unread: true, + }, + Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), 40), + }, + { + Metadata: proton.ImportMetadata{ + AddressID: addr[0].ID, + LabelIDs: []string{}, + Flags: proton.MessageFlagReceived, + Unread: true, + }, + Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), proton.MaxImportSize), + }, + { + Metadata: proton.ImportMetadata{ + AddressID: addr[0].ID, + LabelIDs: []string{}, + Flags: proton.MessageFlagReceived, + Unread: true, + }, + Message: newMessageLiteralWithSubjectAndSize("test1@example.com", "test2@example.com", uuid.NewString(), 20), + }, + } + + reqsCopy := make([]proton.ImportReq, len(reqs)) + for idx := range len(reqs) { + messageCopy := make([]byte, len(reqs[idx].Message)) + copy(messageCopy, reqs[idx].Message) + reqsCopy[idx].Message = messageCopy + } + + _, err = c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), reqs...) + require.ErrorIs(t, err, proton.ErrImportSizeExceeded) + + for idx := range len(reqs) { + require.Equal(t, reqsCopy[idx].Message, reqs[idx].Message) + require.True(t, reqs[idx].GetEncryptedMessageLength() > 0) + } + }) + }) +}