mirror of
https://github.com/ProtonMail/go-proton-api.git
synced 2025-12-23 23:57:50 -05:00
feat(GODT-2361): Use simple import encrypter for simple messages
This commit is contained in:
@@ -2,6 +2,8 @@ package proton
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"strings"
|
||||
@@ -9,6 +11,8 @@ import (
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/text/encoding/htmlindex"
|
||||
"golang.org/x/text/encoding/ianaindex"
|
||||
)
|
||||
|
||||
// CharsetReader returns a charset decoder for the given charset.
|
||||
@@ -17,18 +21,220 @@ var CharsetReader func(charset string, input io.Reader) (io.Reader, error)
|
||||
|
||||
// EncryptRFC822 encrypts the given message literal as a PGP attachment.
|
||||
func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
|
||||
msg, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := tryEncrypt(&buf, kr, rfc822.Parse(literal)); err != nil {
|
||||
return encryptFull(kr, literal)
|
||||
}
|
||||
armored, err := msg.GetArmored()
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// tryEncrypt tries to encrypt the given message section.
|
||||
// It first checks if the message is encrypted/signed or has multiple text parts.
|
||||
// If so, it returns an error -- we need to encrypt the whole message as a PGP attachment.
|
||||
func tryEncrypt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
|
||||
var textCount int
|
||||
|
||||
if err := s.Walk(func(s *rfc822.Section) error {
|
||||
// Ensure we can read the content type.
|
||||
contentType, _, err := s.ContentType()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read content type: %w", err)
|
||||
}
|
||||
|
||||
// Ensure we can read the content disposition.
|
||||
if header, err := s.ParseHeader(); err != nil {
|
||||
return fmt.Errorf("cannot read header: %w", err)
|
||||
} else if header.Has("Content-Disposition") {
|
||||
if _, _, err := rfc822.ParseMediaType(header.Get("Content-Disposition")); err != nil {
|
||||
return fmt.Errorf("cannot read content disposition: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the message is already encrypted or signed.
|
||||
if contentType.SubType() == "encrypted" {
|
||||
return fmt.Errorf("already encrypted")
|
||||
} else if contentType.SubType() == "signed" {
|
||||
return fmt.Errorf("already signed")
|
||||
}
|
||||
|
||||
if contentType.Type() != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if textCount++; textCount > 1 {
|
||||
return fmt.Errorf("multiple text parts")
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return encrypt(w, kr, s)
|
||||
}
|
||||
|
||||
// encrypt encrypts the given message section with the given keyring and writes the result to w.
|
||||
func encrypt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
|
||||
contentType, contentParams, err := s.ContentType()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if contentType.IsMultiPart() {
|
||||
return encryptMultipart(w, kr, s, contentParams["boundary"])
|
||||
}
|
||||
|
||||
if contentType.Type() == "text" || contentType.Type() == "message" {
|
||||
return encryptText(w, kr, s)
|
||||
}
|
||||
|
||||
return encryptAtt(w, kr, s)
|
||||
}
|
||||
|
||||
// encryptMultipart encrypts the given multipart message section with the given keyring and writes the result to w.
|
||||
func encryptMultipart(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section, boundary string) error {
|
||||
// Write the header.
|
||||
if _, err := w.Write(s.Header()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a new multipart writer with the boundary from the header.
|
||||
ww := rfc822.NewMultipartWriter(w, boundary)
|
||||
|
||||
children, err := s.Children()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encrypt each child part.
|
||||
for _, child := range children {
|
||||
if err := ww.AddPart(func(w io.Writer) error {
|
||||
return encrypt(w, kr, child)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return ww.Done()
|
||||
}
|
||||
|
||||
// encryptText encrypts the given text message section with the given keyring and writes the result to w.
|
||||
func encryptText(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
|
||||
contentType, contentParams, err := s.ContentType()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := s.ParseHeader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := s.DecodedBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the Content-Transfer-Encoding header as we decode the body.
|
||||
header.Del("Content-Transfer-Encoding")
|
||||
|
||||
// If the text part has a charset, decode it to UTF-8.
|
||||
if charset, ok := contentParams["charset"]; ok {
|
||||
decoder, err := getCharsetDecoder(bytes.NewReader(body), charset)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if body, err = io.ReadAll(decoder); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header.Set("Content-Type", mime.FormatMediaType(
|
||||
string(contentType),
|
||||
replace(contentParams, "charset", "utf-8")),
|
||||
)
|
||||
}
|
||||
|
||||
// Encrypt the body.
|
||||
enc, err := kr.Encrypt(crypto.NewPlainMessage(body), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Armor the encrypted body.
|
||||
arm, err := enc.GetArmored()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the header.
|
||||
if _, err := w.Write(header.Raw()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the armored body.
|
||||
if _, err := w.Write([]byte(arm)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptAtt encrypts the given attachment section with the given keyring and writes the result to w.
|
||||
func encryptAtt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error {
|
||||
header, err := s.ParseHeader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := s.DecodedBody()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the Content-Transfer-Encoding header to base64.
|
||||
header.Set("Content-Transfer-Encoding", "base64")
|
||||
|
||||
// Encrypt the body.
|
||||
enc, err := kr.Encrypt(crypto.NewPlainMessage(body), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode the encrypted body to base64.
|
||||
b64, err := getBase64(enc.GetBinary())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the header.
|
||||
if _, err := w.Write(header.Raw()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the base64 body.
|
||||
if _, err := w.Write(b64); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptFull builds a PGP/MIME encrypted message from the given literal.
|
||||
func encryptFull(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
|
||||
enc, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, _ := rfc822.Split(literal)
|
||||
arm, err := enc.GetArmored()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headerParsed, err := rfc822.NewHeader(header)
|
||||
header, err := rfc822.Parse(literal).ParseHeader()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -40,7 +246,7 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
|
||||
{
|
||||
newHeader := rfc822.NewEmptyHeader()
|
||||
|
||||
if value, ok := headerParsed.GetChecked("Message-Id"); ok {
|
||||
if value, ok := header.GetChecked("Message-Id"); ok {
|
||||
newHeader.Set("Message-Id", value)
|
||||
}
|
||||
|
||||
@@ -51,23 +257,23 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
|
||||
newHeader.Set("Mime-version", "1.0")
|
||||
newHeader.Set("Content-Type", contentType)
|
||||
|
||||
if value, ok := headerParsed.GetChecked("From"); ok {
|
||||
if value, ok := header.GetChecked("From"); ok {
|
||||
newHeader.Set("From", value)
|
||||
}
|
||||
|
||||
if value, ok := headerParsed.GetChecked("To"); ok {
|
||||
if value, ok := header.GetChecked("To"); ok {
|
||||
newHeader.Set("To", value)
|
||||
}
|
||||
|
||||
if value, ok := headerParsed.GetChecked("Subject"); ok {
|
||||
if value, ok := header.GetChecked("Subject"); ok {
|
||||
newHeader.Set("Subject", value)
|
||||
}
|
||||
|
||||
if value, ok := headerParsed.GetChecked("Date"); ok {
|
||||
if value, ok := header.GetChecked("Date"); ok {
|
||||
newHeader.Set("Date", value)
|
||||
}
|
||||
|
||||
if value, ok := headerParsed.GetChecked("Received"); ok {
|
||||
if value, ok := header.GetChecked("Received"); ok {
|
||||
newHeader.Set("Received", value)
|
||||
}
|
||||
|
||||
@@ -107,7 +313,7 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := writer.Write([]byte(armored))
|
||||
_, err := writer.Write([]byte(arm))
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@@ -121,3 +327,50 @@ func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func getBase64(b []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := encode(base64.NewEncoder(base64.StdEncoding, &buf), b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func encode(wc io.WriteCloser, b []byte) error {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if _, err := buf.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return wc.Close()
|
||||
}
|
||||
|
||||
func getCharsetDecoder(r io.Reader, charset string) (io.Reader, error) {
|
||||
if CharsetReader != nil {
|
||||
if enc, err := CharsetReader(charset, r); err == nil {
|
||||
return enc, nil
|
||||
}
|
||||
}
|
||||
|
||||
if enc, err := ianaindex.MIME.Encoding(strings.ToLower(charset)); err == nil {
|
||||
return enc.NewDecoder().Reader(r), nil
|
||||
}
|
||||
|
||||
if enc, err := ianaindex.MIME.Encoding("cs" + strings.ToLower(charset)); err == nil {
|
||||
return enc.NewDecoder().Reader(r), nil
|
||||
}
|
||||
|
||||
if enc, err := htmlindex.Get(strings.ToLower(charset)); err == nil {
|
||||
return enc.NewDecoder().Reader(r), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown charset: %s", charset)
|
||||
}
|
||||
|
||||
func replace[Key comparable, Value any](m map[Key]Value, key Key, value Value) map[Key]Value {
|
||||
m[key] = value
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -10,7 +10,50 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncryptMessage(t *testing.T) {
|
||||
func TestEncryptMessage_Simple(t *testing.T) {
|
||||
const message = `From: Nathaniel Borenstein <nsb@bellcore.com>
|
||||
To: Ned Freed <ned@innosoft.com>
|
||||
Subject: Sample message (import 2)
|
||||
MIME-Version: 1.0
|
||||
Content-type: text/plain
|
||||
|
||||
This is explicitly typed plain ASCII text.
|
||||
`
|
||||
key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
kr, err := crypto.NewKeyRing(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
encryptedMessage, err := EncryptRFC822(kr, []byte(message))
|
||||
require.NoError(t, err)
|
||||
|
||||
section := rfc822.Parse(encryptedMessage)
|
||||
|
||||
// Check root header:
|
||||
header, err := section.ParseHeader()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, header.Get("From"), "Nathaniel Borenstein <nsb@bellcore.com>")
|
||||
assert.Equal(t, header.Get("To"), "Ned Freed <ned@innosoft.com>")
|
||||
assert.Equal(t, header.Get("Subject"), "Sample message (import 2)")
|
||||
assert.Equal(t, header.Get("MIME-Version"), "1.0")
|
||||
|
||||
// Read the body.
|
||||
body, err := section.DecodedBody()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarmor the PGP message.
|
||||
enc, err := crypto.NewPGPMessageFromArmored(string(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt the PGP message.
|
||||
dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "This is explicitly typed plain ASCII text.\n", dec.GetString())
|
||||
}
|
||||
|
||||
func TestEncryptMessage_MultipleTextParts(t *testing.T) {
|
||||
const message = `From: Nathaniel Borenstein <nsb@bellcore.com>
|
||||
To: Ned Freed <ned@innosoft.com>
|
||||
Subject: Sample message (import 2)
|
||||
|
||||
@@ -400,7 +400,9 @@ func (c *NetCtl) NewRoundTripper(tlsConfig *tls.Config) http.RoundTripper {
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return c.dial(ctx, &tls.Dialer{Config: tlsConfig}, network, addr)
|
||||
},
|
||||
TLSClientConfig: tlsConfig,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: time.Second,
|
||||
ExpectContinueTimeout: time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -527,12 +527,10 @@ func (s *Server) importAttachment(userID, messageID string, att *rfc822.Section)
|
||||
|
||||
var disposition, filename string
|
||||
|
||||
if header.Has("Content-Disposition") {
|
||||
dispType, dispParams, err := mime.ParseMediaType(header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
return proton.Attachment{}, fmt.Errorf("failed to parse attachment content disposition: %w", err)
|
||||
}
|
||||
|
||||
if !header.Has("Content-Disposition") {
|
||||
disposition = "attachment"
|
||||
filename = "attachment.bin"
|
||||
} else if dispType, dispParams, err := mime.ParseMediaType(header.Get("Content-Disposition")); err == nil {
|
||||
disposition = dispType
|
||||
filename = dispParams["filename"]
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user