feat(GODT-2361): Use simple import encrypter for simple messages

This commit is contained in:
James Houlahan
2023-02-15 23:54:44 +01:00
committed by James
parent 3e1eb7e617
commit 0af5d2f084
4 changed files with 317 additions and 21 deletions

View File

@@ -2,6 +2,8 @@ package proton
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"mime"
"strings"
@@ -9,6 +11,8 @@ import (
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/google/uuid"
"golang.org/x/text/encoding/htmlindex"
"golang.org/x/text/encoding/ianaindex"
)
// CharsetReader returns a charset decoder for the given charset.
@@ -17,18 +21,220 @@ var CharsetReader func(charset string, input io.Reader) (io.Reader, error)
// EncryptRFC822 encrypts the given message literal as a PGP attachment.
func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
msg, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr)
if err != nil {
return nil, err
var buf bytes.Buffer
if err := tryEncrypt(&buf, kr, rfc822.Parse(literal)); err != nil {
return encryptFull(kr, literal)
}
armored, err := msg.GetArmored()
return buf.Bytes(), nil
}
// tryEncrypt tries to encrypt the given message section.
// It first checks if the message is encrypted/signed or has multiple text parts.
// If so, it returns an error -- we need to encrypt the whole message as a PGP attachment.
func tryEncrypt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
var textCount int
if err := s.Walk(func(s *rfc822.Section) error {
// Ensure we can read the content type.
contentType, _, err := s.ContentType()
if err != nil {
return fmt.Errorf("cannot read content type: %w", err)
}
// Ensure we can read the content disposition.
if header, err := s.ParseHeader(); err != nil {
return fmt.Errorf("cannot read header: %w", err)
} else if header.Has("Content-Disposition") {
if _, _, err := rfc822.ParseMediaType(header.Get("Content-Disposition")); err != nil {
return fmt.Errorf("cannot read content disposition: %w", err)
}
}
// Check if the message is already encrypted or signed.
if contentType.SubType() == "encrypted" {
return fmt.Errorf("already encrypted")
} else if contentType.SubType() == "signed" {
return fmt.Errorf("already signed")
}
if contentType.Type() != "text" {
return nil
}
if textCount++; textCount > 1 {
return fmt.Errorf("multiple text parts")
}
return nil
}); err != nil {
return err
}
return encrypt(w, kr, s)
}
// encrypt encrypts the given message section with the given keyring and writes the result to w.
func encrypt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
contentType, contentParams, err := s.ContentType()
if err != nil {
return err
}
if contentType.IsMultiPart() {
return encryptMultipart(w, kr, s, contentParams["boundary"])
}
if contentType.Type() == "text" || contentType.Type() == "message" {
return encryptText(w, kr, s)
}
return encryptAtt(w, kr, s)
}
// encryptMultipart encrypts the given multipart message section with the given keyring and writes the result to w.
func encryptMultipart(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section, boundary string) error {
// Write the header.
if _, err := w.Write(s.Header()); err != nil {
return err
}
// Create a new multipart writer with the boundary from the header.
ww := rfc822.NewMultipartWriter(w, boundary)
children, err := s.Children()
if err != nil {
return err
}
// Encrypt each child part.
for _, child := range children {
if err := ww.AddPart(func(w io.Writer) error {
return encrypt(w, kr, child)
}); err != nil {
return err
}
}
return ww.Done()
}
// encryptText encrypts the given text message section with the given keyring and writes the result to w.
func encryptText(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
contentType, contentParams, err := s.ContentType()
if err != nil {
return err
}
header, err := s.ParseHeader()
if err != nil {
return err
}
body, err := s.DecodedBody()
if err != nil {
return err
}
// Remove the Content-Transfer-Encoding header as we decode the body.
header.Del("Content-Transfer-Encoding")
// If the text part has a charset, decode it to UTF-8.
if charset, ok := contentParams["charset"]; ok {
decoder, err := getCharsetDecoder(bytes.NewReader(body), charset)
if err != nil {
return err
}
if body, err = io.ReadAll(decoder); err != nil {
return err
}
header.Set("Content-Type", mime.FormatMediaType(
string(contentType),
replace(contentParams, "charset", "utf-8")),
)
}
// Encrypt the body.
enc, err := kr.Encrypt(crypto.NewPlainMessage(body), nil)
if err != nil {
return err
}
// Armor the encrypted body.
arm, err := enc.GetArmored()
if err != nil {
return err
}
// Write the header.
if _, err := w.Write(header.Raw()); err != nil {
return err
}
// Write the armored body.
if _, err := w.Write([]byte(arm)); err != nil {
return err
}
return nil
}
// encryptAtt encrypts the given attachment section with the given keyring and writes the result to w.
func encryptAtt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
header, err := s.ParseHeader()
if err != nil {
return err
}
body, err := s.DecodedBody()
if err != nil {
return err
}
// Set the Content-Transfer-Encoding header to base64.
header.Set("Content-Transfer-Encoding", "base64")
// Encrypt the body.
enc, err := kr.Encrypt(crypto.NewPlainMessage(body), nil)
if err != nil {
return err
}
// Encode the encrypted body to base64.
b64, err := getBase64(enc.GetBinary())
if err != nil {
return err
}
// Write the header.
if _, err := w.Write(header.Raw()); err != nil {
return err
}
// Write the base64 body.
if _, err := w.Write(b64); err != nil {
return err
}
return nil
}
// encryptFull builds a PGP/MIME encrypted message from the given literal.
func encryptFull(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
enc, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr)
if err != nil {
return nil, err
}
header, _ := rfc822.Split(literal)
arm, err := enc.GetArmored()
if err != nil {
return nil, err
}
headerParsed, err := rfc822.NewHeader(header)
header, err := rfc822.Parse(literal).ParseHeader()
if err != nil {
return nil, err
}
@@ -40,7 +246,7 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
{
newHeader := rfc822.NewEmptyHeader()
if value, ok := headerParsed.GetChecked("Message-Id"); ok {
if value, ok := header.GetChecked("Message-Id"); ok {
newHeader.Set("Message-Id", value)
}
@@ -51,23 +257,23 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
newHeader.Set("Mime-version", "1.0")
newHeader.Set("Content-Type", contentType)
if value, ok := headerParsed.GetChecked("From"); ok {
if value, ok := header.GetChecked("From"); ok {
newHeader.Set("From", value)
}
if value, ok := headerParsed.GetChecked("To"); ok {
if value, ok := header.GetChecked("To"); ok {
newHeader.Set("To", value)
}
if value, ok := headerParsed.GetChecked("Subject"); ok {
if value, ok := header.GetChecked("Subject"); ok {
newHeader.Set("Subject", value)
}
if value, ok := headerParsed.GetChecked("Date"); ok {
if value, ok := header.GetChecked("Date"); ok {
newHeader.Set("Date", value)
}
if value, ok := headerParsed.GetChecked("Received"); ok {
if value, ok := header.GetChecked("Received"); ok {
newHeader.Set("Received", value)
}
@@ -107,7 +313,7 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
return err
}
_, err := writer.Write([]byte(armored))
_, err := writer.Write([]byte(arm))
return err
}); err != nil {
return nil, err
@@ -121,3 +327,50 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
return buf.Bytes(), nil
}
func getBase64(b []byte) ([]byte, error) {
var buf bytes.Buffer
if err := encode(base64.NewEncoder(base64.StdEncoding, &buf), b); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func encode(wc io.WriteCloser, b []byte) error {
var buf bytes.Buffer
if _, err := buf.Write(b); err != nil {
return err
}
return wc.Close()
}
func getCharsetDecoder(r io.Reader, charset string) (io.Reader, error) {
if CharsetReader != nil {
if enc, err := CharsetReader(charset, r); err == nil {
return enc, nil
}
}
if enc, err := ianaindex.MIME.Encoding(strings.ToLower(charset)); err == nil {
return enc.NewDecoder().Reader(r), nil
}
if enc, err := ianaindex.MIME.Encoding("cs" + strings.ToLower(charset)); err == nil {
return enc.NewDecoder().Reader(r), nil
}
if enc, err := htmlindex.Get(strings.ToLower(charset)); err == nil {
return enc.NewDecoder().Reader(r), nil
}
return nil, fmt.Errorf("unknown charset: %s", charset)
}
func replace[Key comparable, Value any](m map[Key]Value, key Key, value Value) map[Key]Value {
m[key] = value
return m
}

View File

@@ -10,7 +10,50 @@ import (
"github.com/stretchr/testify/require"
)
func TestEncryptMessage(t *testing.T) {
func TestEncryptMessage_Simple(t *testing.T) {
const message = `From: Nathaniel Borenstein <nsb@bellcore.com>
To: Ned Freed <ned@innosoft.com>
Subject: Sample message (import 2)
MIME-Version: 1.0
Content-type: text/plain
This is explicitly typed plain ASCII text.
`
key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0)
require.NoError(t, err)
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
encryptedMessage, err := EncryptRFC822(kr, []byte(message))
require.NoError(t, err)
section := rfc822.Parse(encryptedMessage)
// Check root header:
header, err := section.ParseHeader()
require.NoError(t, err)
assert.Equal(t, header.Get("From"), "Nathaniel Borenstein <nsb@bellcore.com>")
assert.Equal(t, header.Get("To"), "Ned Freed <ned@innosoft.com>")
assert.Equal(t, header.Get("Subject"), "Sample message (import 2)")
assert.Equal(t, header.Get("MIME-Version"), "1.0")
// Read the body.
body, err := section.DecodedBody()
require.NoError(t, err)
// Unarmor the PGP message.
enc, err := crypto.NewPGPMessageFromArmored(string(body))
require.NoError(t, err)
// Decrypt the PGP message.
dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime())
require.NoError(t, err)
require.Equal(t, "This is explicitly typed plain ASCII text.\n", dec.GetString())
}
func TestEncryptMessage_MultipleTextParts(t *testing.T) {
const message = `From: Nathaniel Borenstein <nsb@bellcore.com>
To: Ned Freed <ned@innosoft.com>
Subject: Sample message (import 2)

View File

@@ -400,7 +400,9 @@ func (c *NetCtl) NewRoundTripper(tlsConfig *tls.Config) http.RoundTripper {
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return c.dial(ctx, &tls.Dialer{Config: tlsConfig}, network, addr)
},
TLSClientConfig: tlsConfig,
TLSClientConfig: tlsConfig,
ResponseHeaderTimeout: time.Second,
ExpectContinueTimeout: time.Second,
}
}

View File

@@ -527,12 +527,10 @@ func (s *Server) importAttachment(userID, messageID string, att *rfc822.Section)
var disposition, filename string
if header.Has("Content-Disposition") {
dispType, dispParams, err := mime.ParseMediaType(header.Get("Content-Disposition"))
if err != nil {
return proton.Attachment{}, fmt.Errorf("failed to parse attachment content disposition: %w", err)
}
if !header.Has("Content-Disposition") {
disposition = "attachment"
filename = "attachment.bin"
} else if dispType, dispParams, err := mime.ParseMediaType(header.Get("Content-Disposition")); err == nil {
disposition = dispType
filename = dispParams["filename"]
} else {