feat: Initial open source commit

This commit is contained in:
James Houlahan
2022-11-23 11:17:54 +01:00
commit 2323ea7360
133 changed files with 16622 additions and 0 deletions

28
.github/workflows/check.yml vendored Normal file
View File

@@ -0,0 +1,28 @@
name: Lint and Test
on: push
jobs:
check:
runs-on: ubuntu-latest
steps:
- name: Get sources
uses: actions/checkout@v3
- name: Set up Go 1.18
uses: actions/setup-go@v3
with:
go-version: '1.18'
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.50.0
args: --timeout=180s
skip-cache: true
- name: Run tests
run: go test -v ./...
- name: Run tests with race check
run: go test -v -race ./...

10
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,10 @@
# Contribution Policy
By making a contribution to this project:
1. I assign any and all copyright related to the contribution to Proton AG;
2. I certify that the contribution was created in whole by me;
3. I understand and agree that this project and the contribution are public
and that a record of the contribution (including all personal information I
submit with it) is maintained indefinitely and may be redistributed with
this project or the open source license(s) involved.

22
LICENSE Normal file
View File

@@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2020 James Houlahan
Copyright (c) 2022 Proton AG
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

23
README.md Normal file
View File

@@ -0,0 +1,23 @@
# Go Proton API
<a href="https://github.com/ProtonMail/go-proton-api/actions/workflows/check.yml"><img src="https://github.com/ProtonMail/go-proton-api/actions/workflows/check.yml/badge.svg?branch=master" alt="CI Status"></a>
<a href="https://pkg.go.dev/github.com/ProtonMail/go-proton-api"><img src="https://pkg.go.dev/badge/github.com/ProtonMail/go-proton-api" alt="GoDoc"></a>
<a href="https://goreportcard.com/report/ProtonMail/go-proton-api"><img src="https://goreportcard.com/badge/ProtonMail/go-proton-api" alt="Go Report Card"></a>
<a href="LICENSE"><img src="https://img.shields.io/github/license/ProtonMail/go-proton-api.svg" alt="License"></a>
This repository holds Go Proton API, a Go library implementing a client and development server for (a subset of) the Proton REST API.
The license can be found in the [LICENSE](./LICENSE) file.
For the contribution policy, see [CONTRIBUTING](./CONTRIBUTING.md).
## Environment variables
Most of the integration tests run locally. The ones that interact with Proton servers require the following environment variables set:
- ```GO_PROTON_API_TEST_USERNAME```
- ```GO_PROTON_API_TEST_PASSWORD```
## Contribution
The library is maintained by Proton AG, and is not actively looking for contributors.

46
address.go Normal file
View File

@@ -0,0 +1,46 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
"golang.org/x/exp/slices"
)
func (c *Client) GetAddresses(ctx context.Context) ([]Address, error) {
var res struct {
Addresses []Address
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/core/v4/addresses")
}); err != nil {
return nil, err
}
slices.SortFunc(res.Addresses, func(a, b Address) bool {
return a.Order < b.Order
})
return res.Addresses, nil
}
func (c *Client) GetAddress(ctx context.Context, addressID string) (Address, error) {
var res struct {
Address Address
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/core/v4/addresses/" + addressID)
}); err != nil {
return Address{}, err
}
return res.Address, nil
}
func (c *Client) OrderAddresses(ctx context.Context, req OrderAddressesReq) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/core/v4/addresses/order")
})
}

27
address_types.go Normal file
View File

@@ -0,0 +1,27 @@
package proton
type Address struct {
ID string
Email string
Send Bool
Receive Bool
Status AddressStatus
Order int
DisplayName string
Keys Keys
}
type OrderAddressesReq struct {
AddressIDs []string
}
type AddressStatus int
const (
AddressStatusDisabled AddressStatus = iota
AddressStatusEnabled
AddressStatusDeleting
)

33
atomic.go Normal file
View File

@@ -0,0 +1,33 @@
package proton
import "sync/atomic"
type atomicUint64 struct {
v uint64
}
func (x *atomicUint64) Load() uint64 { return atomic.LoadUint64(&x.v) }
func (x *atomicUint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) }
func (x *atomicUint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) }
func (x *atomicUint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) }
type atomicBool struct {
v uint32
}
func (x *atomicBool) Load() bool { return atomic.LoadUint32(&x.v) != 0 }
func (x *atomicBool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) }
func (x *atomicBool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 }
func b32(b bool) uint32 {
if b {
return 1
}
return 0
}

79
attachment.go Normal file
View File

@@ -0,0 +1,79 @@
package proton
import (
"bytes"
"context"
"fmt"
"io"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetAttachment(ctx context.Context, attachmentID string) ([]byte, error) {
return c.attPool().ProcessOne(ctx, attachmentID)
}
func (c *Client) UploadAttachment(ctx context.Context, addrKR *crypto.KeyRing, req CreateAttachmentReq) (Attachment, error) {
var res struct {
Attachment Attachment
}
sig, err := addrKR.SignDetached(crypto.NewPlainMessage(req.Body))
if err != nil {
return Attachment{}, fmt.Errorf("failed to sign attachment: %w", err)
}
enc, err := addrKR.EncryptAttachment(crypto.NewPlainMessage(req.Body), req.Filename)
if err != nil {
return Attachment{}, fmt.Errorf("failed to encrypt attachment: %w", err)
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).
SetMultipartFormData(map[string]string{
"MessageID": req.MessageID,
"Filename": req.Filename,
"MIMEType": string(req.MIMEType),
"Disposition": string(req.Disposition),
"ContentID": req.ContentID,
}).
SetMultipartFields(
&resty.MultipartField{
Param: "KeyPackets",
FileName: "blob",
ContentType: "application/octet-stream",
Reader: bytes.NewReader(enc.KeyPacket),
},
&resty.MultipartField{
Param: "DataPacket",
FileName: "blob",
ContentType: "application/octet-stream",
Reader: bytes.NewReader(enc.DataPacket),
},
&resty.MultipartField{
Param: "Signature",
FileName: "blob",
ContentType: "application/octet-stream",
Reader: bytes.NewReader(sig.GetBinary()),
},
).
Post("/mail/v4/attachments")
}); err != nil {
return Attachment{}, err
}
return res.Attachment, nil
}
func (c *Client) getAttachment(ctx context.Context, attachmentID string) ([]byte, error) {
res, err := c.doRes(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetDoNotParseResponse(true).Get("/mail/v4/attachments/" + attachmentID)
})
if err != nil {
return nil, err
}
defer res.RawBody().Close()
return io.ReadAll(res.RawBody())
}

36
attachment_types.go Normal file
View File

@@ -0,0 +1,36 @@
package proton
import (
"github.com/ProtonMail/gluon/rfc822"
)
type Attachment struct {
ID string
Name string
Size int64
MIMEType rfc822.MIMEType
Disposition Disposition
Headers Headers
KeyPackets string
Signature string
}
type Disposition string
const (
InlineDisposition Disposition = "inline"
AttachmentDisposition Disposition = "attachment"
)
type CreateAttachmentReq struct {
MessageID string
Filename string
MIMEType rfc822.MIMEType
Disposition Disposition
ContentID string
Body []byte
}

45
auth.go Normal file
View File

@@ -0,0 +1,45 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
)
func (c *Client) Auth2FA(ctx context.Context, req Auth2FAReq) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Post("/core/v4/auth/2fa")
})
}
func (c *Client) AuthDelete(ctx context.Context) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/core/v4/auth")
})
}
func (c *Client) AuthSessions(ctx context.Context) ([]AuthSession, error) {
var res struct {
Sessions []AuthSession
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/auth/v4/sessions")
}); err != nil {
return nil, err
}
return res.Sessions, nil
}
func (c *Client) AuthRevoke(ctx context.Context, authUID string) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/auth/v4/sessions/" + authUID)
})
}
func (c *Client) AuthRevokeAll(ctx context.Context) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/auth/v4/sessions")
})
}

103
auth_test.go Normal file
View File

@@ -0,0 +1,103 @@
package proton_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
)
func TestAutomaticAuthRefresh(t *testing.T) {
wantAuth := proton.Auth{
UID: "testUID",
AccessToken: "testAcc",
RefreshToken: "testRef",
}
mux := http.NewServeMux()
mux.HandleFunc("/core/v4/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(wantAuth); err != nil {
panic(err)
}
})
mux.HandleFunc("/core/v4/users", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
ts := httptest.NewServer(mux)
defer ts.Close()
var gotAuth proton.Auth
// Create a new client.
c := proton.New(proton.WithHostURL(ts.URL)).NewClient("uid", "acc", "ref", time.Now().Add(-time.Second))
defer c.Close()
// Register an auth handler.
c.AddAuthHandler(func(auth proton.Auth) { gotAuth = auth })
// Make a request with an access token that already expired one second ago.
if _, err := c.GetUser(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
}
// The auth callback should have been called.
if !cmp.Equal(gotAuth, wantAuth) {
t.Fatal("got unexpected auth", gotAuth)
}
}
func TestAuth(t *testing.T) {
s := server.New()
defer s.Close()
_, _, err := s.CreateUser("username", "email@pm.me", []byte("password"))
require.NoError(t, err)
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.InsecureTransport()),
)
defer m.Close()
// Create one session.
c1, auth1, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
require.NoError(t, err)
// Revoke all other sessions.
require.NoError(t, c1.AuthRevokeAll(context.Background()))
// Create another session.
c2, _, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
require.NoError(t, err)
// There should be two sessions.
sessions, err := c1.AuthSessions(context.Background())
require.NoError(t, err)
require.Len(t, sessions, 2)
// Revoke the first session.
require.NoError(t, c2.AuthRevoke(context.Background(), auth1.UID))
// The first session should no longer work.
require.Error(t, c1.AuthDelete(context.Background()))
// There should be one session remaining.
remaining, err := c2.AuthSessions(context.Background())
require.NoError(t, err)
require.Len(t, remaining, 1)
// Delete the last session.
require.NoError(t, c2.AuthDelete(context.Background()))
}

95
auth_types.go Normal file
View File

@@ -0,0 +1,95 @@
package proton
type AuthInfoReq struct {
Username string
}
type AuthInfo struct {
Version int
Modulus string
ServerEphemeral string
Salt string
SRPSession string
TwoFA TwoFAInfo `json:"2FA"`
}
type U2FReq struct {
KeyHandle string
ClientData string
SignatureData string
}
type AuthReq struct {
Username string
ClientEphemeral string
ClientProof string
SRPSession string
U2F U2FReq
}
type Auth struct {
UserID string
UID string
AccessToken string
RefreshToken string
ServerProof string
ExpiresIn int
Scope string
TwoFA TwoFAInfo `json:"2FA"`
PasswordMode PasswordMode
}
type RegisteredKey struct {
Version string
KeyHandle string
}
type U2FInfo struct {
Challenge string
RegisteredKeys []RegisteredKey
}
type TwoFAInfo struct {
Enabled TwoFAStatus
U2F U2FInfo
}
type TwoFAStatus int
const (
TwoFADisabled TwoFAStatus = iota
TOTPEnabled
)
type PasswordMode int
const (
OnePasswordMode PasswordMode = iota + 1
TwoPasswordMode
)
type Auth2FAReq struct {
TwoFactorCode string
}
type AuthRefreshReq struct {
UID string
RefreshToken string
ResponseType string
GrantType string
RedirectURI string
State string
}
type AuthSession struct {
UID string
CreateTime int64
ClientID string
MemberID string
Revocable Bool
LocalizedClientName string
}

19
block.go Normal file
View File

@@ -0,0 +1,19 @@
package proton
import (
"context"
"io"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetBlock(ctx context.Context, url string) (io.ReadCloser, error) {
res, err := c.doRes(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetDoNotParseResponse(true).Get(url)
})
if err != nil {
return nil, err
}
return res.RawBody(), nil
}

38
boolean.go Normal file
View File

@@ -0,0 +1,38 @@
package proton
import "encoding/json"
// Bool is a convenience type for boolean values; it converts from APIBool to Go's builtin bool type.
type Bool bool
// APIBool is the boolean type used by the API (0 or 1).
type APIBool int
const (
APIFalse APIBool = iota
APITrue
)
func (b *Bool) UnmarshalJSON(data []byte) error {
var v APIBool
if err := json.Unmarshal(data, &v); err != nil {
return err
}
*b = Bool(v == APITrue)
return nil
}
func (b Bool) MarshalJSON() ([]byte, error) {
var v APIBool
if b {
v = APITrue
} else {
v = APIFalse
}
return json.Marshal(v)
}

77
calendar.go Normal file
View File

@@ -0,0 +1,77 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetCalendars(ctx context.Context) ([]Calendar, error) {
var res struct {
Calendars []Calendar
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1")
}); err != nil {
return nil, err
}
return res.Calendars, nil
}
func (c *Client) GetCalendar(ctx context.Context, calendarID string) (Calendar, error) {
var res struct {
Calendar Calendar
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1/" + calendarID)
}); err != nil {
return Calendar{}, err
}
return res.Calendar, nil
}
func (c *Client) GetCalendarKeys(ctx context.Context, calendarID string) (CalendarKeys, error) {
var res struct {
Keys CalendarKeys
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/keys")
}); err != nil {
return nil, err
}
return res.Keys, nil
}
func (c *Client) GetCalendarMembers(ctx context.Context, calendarID string) ([]CalendarMember, error) {
var res struct {
Members []CalendarMember
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/members")
}); err != nil {
return nil, err
}
return res.Members, nil
}
func (c *Client) GetCalendarPassphrase(ctx context.Context, calendarID string) (CalendarPassphrase, error) {
var res struct {
Passphrase CalendarPassphrase
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/passphrase")
}); err != nil {
return CalendarPassphrase{}, err
}
return res.Passphrase, nil
}

66
calendar_event.go Normal file
View File

@@ -0,0 +1,66 @@
package proton
import (
"context"
"net/url"
"strconv"
"github.com/go-resty/resty/v2"
)
func (c *Client) CountCalendarEvents(ctx context.Context, calendarID string) (int, error) {
var res struct {
Total int
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/events")
}); err != nil {
return 0, err
}
return res.Total, nil
}
// TODO: For now, the query params are partially constant -- should they be configurable?
func (c *Client) GetCalendarEvents(ctx context.Context, calendarID string, page, pageSize int, filter url.Values) ([]CalendarEvent, error) {
var res struct {
Events []CalendarEvent
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParams(map[string]string{
"Page": strconv.Itoa(page),
"PageSize": strconv.Itoa(pageSize),
}).SetQueryParamsFromValues(filter).SetResult(&res).Get("/calendar/v1/" + calendarID + "/events")
}); err != nil {
return nil, err
}
return res.Events, nil
}
func (c *Client) GetAllCalendarEvents(ctx context.Context, calendarID string, filter url.Values) ([]CalendarEvent, error) {
total, err := c.CountCalendarEvents(ctx, calendarID)
if err != nil {
return nil, err
}
return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]CalendarEvent, error) {
return c.GetCalendarEvents(ctx, calendarID, page, pageSize, filter)
})
}
func (c *Client) GetCalendarEvent(ctx context.Context, calendarID, eventID string) (CalendarEvent, error) {
var res struct {
Event CalendarEvent
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/events/" + eventID)
}); err != nil {
return CalendarEvent{}, err
}
return res.Event, nil
}

110
calendar_event_types.go Normal file
View File

@@ -0,0 +1,110 @@
package proton
import (
"encoding/base64"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
type CalendarEvent struct {
ID string
UID string
CalendarID string
SharedEventID string
CreateTime int64
LastEditTime int64
StartTime int64
StartTimezone string
EndTime int64
EndTimezone string
FullDay Bool
Author string
Permissions CalendarPermissions
Attendees []CalendarAttendee
SharedKeyPacket string
CalendarKeyPacket string
SharedEvents []CalendarEventPart
CalendarEvents []CalendarEventPart
AttendeesEvents []CalendarEventPart
PersonalEvents []CalendarEventPart
}
// TODO: Only personal events have MemberID; should we have a different type for that?
type CalendarEventPart struct {
MemberID string
Type CalendarEventType
Data string
Signature string
Author string
}
func (part CalendarEventPart) Decode(calKR *crypto.KeyRing, addrKR *crypto.KeyRing, kp []byte) error {
if part.Type&CalendarEventTypeEncrypted != 0 {
var enc *crypto.PGPMessage
if kp != nil {
raw, err := base64.StdEncoding.DecodeString(part.Data)
if err != nil {
return err
}
enc = crypto.NewPGPSplitMessage(kp, raw).GetPGPMessage()
} else {
var err error
if enc, err = crypto.NewPGPMessageFromArmored(part.Data); err != nil {
return err
}
}
dec, err := calKR.Decrypt(enc, nil, crypto.GetUnixTime())
if err != nil {
return err
}
part.Data = dec.GetString()
}
if part.Type&CalendarEventTypeSigned != 0 {
sig, err := crypto.NewPGPSignatureFromArmored(part.Signature)
if err != nil {
return err
}
if err := addrKR.VerifyDetached(crypto.NewPlainMessageFromString(part.Data), sig, crypto.GetUnixTime()); err != nil {
return err
}
}
return nil
}
type CalendarEventType int
const (
CalendarEventTypeClear CalendarEventType = iota
CalendarEventTypeEncrypted
CalendarEventTypeSigned
)
type CalendarAttendee struct {
ID string
Token string
Status CalendarAttendeeStatus
Permissions CalendarPermissions
}
// TODO: What is this?
type CalendarAttendeeStatus int
const (
CalendarAttendeeStatusPending CalendarAttendeeStatus = iota
CalendarAttendeeStatusMaybe
CalendarAttendeeStatusNo
CalendarAttendeeStatusYes
)

140
calendar_types.go Normal file
View File

@@ -0,0 +1,140 @@
package proton
import (
"errors"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
type Calendar struct {
ID string
Name string
Description string
Color string
Display Bool
Type CalendarType
Flags CalendarFlag
}
type CalendarFlag int64
const (
CalendarFlagActive CalendarFlag = 1 << iota
CalendarFlagUpdatePassphrase
CalendarFlagResetNeeded
CalendarFlagIncompleteSetup
CalendarFlagLostAccess
)
type CalendarType int
const (
CalendarTypeNormal CalendarType = iota
CalendarTypeSubscribed
)
type CalendarKey struct {
ID string
CalendarID string
PassphraseID string
PrivateKey string
Flags CalendarKeyFlag
}
func (key CalendarKey) Unlock(passphrase []byte) (*crypto.Key, error) {
lockedKey, err := crypto.NewKeyFromArmored(key.PrivateKey)
if err != nil {
return nil, err
}
return lockedKey.Unlock(passphrase)
}
type CalendarKeys []CalendarKey
func (keys CalendarKeys) Unlock(passphrase []byte) (*crypto.KeyRing, error) {
kr, err := crypto.NewKeyRing(nil)
if err != nil {
return nil, err
}
for _, key := range keys {
if k, err := key.Unlock(passphrase); err != nil {
continue
} else if err := kr.AddKey(k); err != nil {
return nil, err
}
}
return kr, nil
}
// TODO: What is this?
type CalendarKeyFlag int64
const (
CalendarKeyFlagActive CalendarKeyFlag = 1 << iota
CalendarKeyFlagPrimary
)
type CalendarMember struct {
ID string
Permissions CalendarPermissions
Email string
Color string
Display Bool
CalendarID string
}
// TODO: What is this?
type CalendarPermissions int
// TODO: Support invitations.
type CalendarPassphrase struct {
ID string
Flags CalendarPassphraseFlag
MemberPassphrases []MemberPassphrase
}
func (passphrase CalendarPassphrase) Decrypt(memberID string, addrKR *crypto.KeyRing) ([]byte, error) {
for _, passphrase := range passphrase.MemberPassphrases {
if passphrase.MemberID == memberID {
return passphrase.decrypt(addrKR)
}
}
return nil, errors.New("no such member passphrase")
}
// TODO: What is this?
type CalendarPassphraseFlag int64
type MemberPassphrase struct {
MemberID string
Passphrase string
Signature string
}
func (passphrase MemberPassphrase) decrypt(addrKR *crypto.KeyRing) ([]byte, error) {
msg, err := crypto.NewPGPMessageFromArmored(passphrase.Passphrase)
if err != nil {
return nil, err
}
sig, err := crypto.NewPGPSignatureFromArmored(passphrase.Signature)
if err != nil {
return nil, err
}
dec, err := addrKR.Decrypt(msg, nil, crypto.GetUnixTime())
if err != nil {
return nil, err
}
if err := addrKR.VerifyDetached(dec, sig, crypto.GetUnixTime()); err != nil {
return nil, err
}
return dec.GetBinary(), nil
}

205
client.go Normal file
View File

@@ -0,0 +1,205 @@
package proton
import (
"context"
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/bradenaw/juniper/xsync"
"github.com/go-resty/resty/v2"
)
// clientID is a unique identifier for a client.
var clientID uint64
// AuthHandler is given any new auths that are returned from the API due to an unexpected auth refresh.
type AuthHandler func(Auth)
// Handler is a generic function that can be registered for a certain event (e.g. deauth, API code).
type Handler func()
// Client is the proton client.
type Client struct {
m *Manager
// clientID is this client's unique ID.
clientID uint64
// attPool is the (lazy-initialized) pool of goroutines that fetch attachments.
attPool func() *Pool[string, []byte]
uid string
acc string
ref string
exp time.Time
authLock sync.RWMutex
authHandlers []AuthHandler
deauthHandlers []Handler
hookLock sync.RWMutex
deauthOnce sync.Once
}
func newClient(m *Manager, uid string) *Client {
c := &Client{
m: m,
uid: uid,
clientID: atomic.AddUint64(&clientID, 1),
}
c.attPool = xsync.Lazy(func() *Pool[string, []byte] {
return NewPool(m.attPoolSize, c.getAttachment)
})
return c
}
func (c *Client) AddAuthHandler(handler AuthHandler) {
c.hookLock.Lock()
defer c.hookLock.Unlock()
c.authHandlers = append(c.authHandlers, handler)
}
func (c *Client) AddDeauthHandler(handler Handler) {
c.hookLock.Lock()
defer c.hookLock.Unlock()
c.deauthHandlers = append(c.deauthHandlers, handler)
}
func (c *Client) AddPreRequestHook(hook resty.RequestMiddleware) {
c.hookLock.Lock()
defer c.hookLock.Unlock()
c.m.rc.OnBeforeRequest(func(rc *resty.Client, r *resty.Request) error {
if clientID, ok := ClientIDFromContext(r.Context()); !ok || clientID != c.clientID {
return nil
}
return hook(rc, r)
})
}
func (c *Client) AddPostRequestHook(hook resty.ResponseMiddleware) {
c.hookLock.Lock()
defer c.hookLock.Unlock()
c.m.rc.OnAfterResponse(func(rc *resty.Client, r *resty.Response) error {
if clientID, ok := ClientIDFromContext(r.Request.Context()); !ok || clientID != c.clientID {
return nil
}
return hook(rc, r)
})
}
func (c *Client) Close() {
c.attPool().Done()
c.authLock.Lock()
defer c.authLock.Unlock()
c.uid = ""
c.acc = ""
c.ref = ""
c.exp = time.Time{}
c.hookLock.Lock()
defer c.hookLock.Unlock()
c.authHandlers = nil
c.deauthHandlers = nil
}
func (c *Client) withAuth(acc, ref string, exp time.Time) *Client {
c.acc = acc
c.ref = ref
c.exp = exp
return c
}
func (c *Client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) error {
if _, err := c.doRes(ctx, fn); err != nil {
return err
}
return nil
}
func (c *Client) doRes(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) {
c.hookLock.RLock()
defer c.hookLock.RUnlock()
req, err := c.newReq(ctx)
if err != nil {
return nil, err
}
// Perform the request.
res, err := fn(req)
// If we receive no response, we can't do anything.
if res.RawResponse == nil {
return nil, fmt.Errorf("received no response from API: %w", err)
}
// If we receive a 401, notify deauth handlers.
if res.StatusCode() == http.StatusUnauthorized {
c.deauthOnce.Do(func() {
for _, handler := range c.deauthHandlers {
handler()
}
})
}
return res, err
}
func (c *Client) newReq(ctx context.Context) (*resty.Request, error) {
c.authLock.Lock()
defer c.authLock.Unlock()
r := c.m.r(WithClient(ctx, c.clientID))
if c.uid != "" {
r.SetHeader("x-pm-uid", c.uid)
}
if time.Now().After(c.exp) {
auth, err := c.m.authRefresh(ctx, c.uid, c.ref)
if err != nil {
return nil, err
}
c.acc = auth.AccessToken
c.ref = auth.RefreshToken
c.exp = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second)
if err := c.handleAuth(auth); err != nil {
return nil, err
}
}
if c.acc != "" {
r.SetAuthToken(c.acc)
}
return r, nil
}
func (c *Client) handleAuth(auth Auth) error {
c.hookLock.RLock()
defer c.hookLock.RUnlock()
for _, handler := range c.authHandlers {
handler(auth)
}
return nil
}

141
contact.go Normal file
View File

@@ -0,0 +1,141 @@
package proton
import (
"context"
"strconv"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetContact(ctx context.Context, contactID string) (Contact, error) {
var res struct {
Contact Contact
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/contacts/v4/" + contactID)
}); err != nil {
return Contact{}, err
}
return res.Contact, nil
}
func (c *Client) CountContacts(ctx context.Context) (int, error) {
var res struct {
Total int
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/contacts/v4")
}); err != nil {
return 0, err
}
return res.Total, nil
}
func (c *Client) CountContactEmails(ctx context.Context, email string) (int, error) {
var res struct {
Total int
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).SetQueryParam("Email", email).Get("/contacts/v4/emails")
}); err != nil {
return 0, err
}
return res.Total, nil
}
func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) {
var res struct {
Contacts []Contact
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParams(map[string]string{
"Page": strconv.Itoa(page),
"PageSize": strconv.Itoa(pageSize),
}).SetResult(&res).Get("/contacts/v4")
}); err != nil {
return nil, err
}
return res.Contacts, nil
}
func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) {
total, err := c.CountContacts(ctx)
if err != nil {
return nil, err
}
return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]Contact, error) {
return c.GetContacts(ctx, page, pageSize)
})
}
func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) {
var res struct {
ContactEmails []ContactEmail
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParams(map[string]string{
"Page": strconv.Itoa(page),
"PageSize": strconv.Itoa(pageSize),
"Email": email,
}).SetResult(&res).Get("/contacts/v4/emails")
}); err != nil {
return nil, err
}
return res.ContactEmails, nil
}
func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) {
total, err := c.CountContactEmails(ctx, email)
if err != nil {
return nil, err
}
return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) {
return c.GetContactEmails(ctx, email, page, pageSize)
})
}
func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) {
var res struct {
Responses []CreateContactsRes
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/contacts/v4")
}); err != nil {
return nil, err
}
return res.Responses, nil
}
func (c *Client) UpdateContact(ctx context.Context, contactID string, req UpdateContactReq) (Contact, error) {
var res struct {
Contact Contact
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/contacts/v4/" + contactID)
}); err != nil {
return Contact{}, err
}
return res.Contact, nil
}
func (c *Client) DeleteContacts(ctx context.Context, req DeleteContactsReq) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/contacts/v4/delete")
})
}

374
contact_card.go Normal file
View File

@@ -0,0 +1,374 @@
package proton
import (
"bytes"
"errors"
"strings"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-vcard"
)
const (
FieldPMScheme = "X-PM-SCHEME"
FieldPMSign = "X-PM-SIGN"
FieldPMEncrypt = "X-PM-ENCRYPT"
FieldPMMIMEType = "X-PM-MIMETYPE"
)
type Cards []*Card
func (c *Cards) Merge(kr *crypto.KeyRing) (vcard.Card, error) {
merged := newVCard()
for _, card := range *c {
dec, err := card.decode(kr)
if err != nil {
return nil, err
}
for k, fields := range dec {
for _, f := range fields {
merged.Add(k, f)
}
}
}
return merged, nil
}
func (c *Cards) Get(cardType CardType) (*Card, bool) {
for _, card := range *c {
if card.Type == cardType {
return card, true
}
}
return nil, false
}
type Card struct {
Type CardType
Data string
Signature string
}
type CardType int
const (
CardTypeClear CardType = iota
CardTypeEncrypted
CardTypeSigned
)
func NewCard(kr *crypto.KeyRing, cardType CardType) (*Card, error) {
card := &Card{Type: cardType}
if err := card.encode(kr, newVCard()); err != nil {
return nil, err
}
return card, nil
}
func newVCard() vcard.Card {
card := make(vcard.Card)
card.AddValue(vcard.FieldVersion, "4.0")
return card
}
func (c Card) Get(kr *crypto.KeyRing, key string) ([]*vcard.Field, error) {
dec, err := c.decode(kr)
if err != nil {
return nil, err
}
return dec[key], nil
}
func (c *Card) Set(kr *crypto.KeyRing, key, value string) error {
dec, err := c.decode(kr)
if err != nil {
return err
}
if field := dec.Get(key); field != nil {
field.Value = value
return c.encode(kr, dec)
}
dec.AddValue(key, value)
return c.encode(kr, dec)
}
func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error {
dec, err := c.decode(kr)
if err != nil {
return err
}
c.Type = cardType
return c.encode(kr, dec)
}
// GetGroup returns a type to manipulate the group defined by the given key/value pair.
func (c Card) GetGroup(kr *crypto.KeyRing, groupKey, groupValue string) (CardGroup, error) {
group, err := c.getGroup(kr, groupKey, groupValue)
if err != nil {
return CardGroup{}, err
}
return CardGroup{Card: c, kr: kr, group: group}, nil
}
// DeleteGroup removes all values in the group defined by the given key/value pair.
func (c *Card) DeleteGroup(kr *crypto.KeyRing, groupKey, groupValue string) error {
group, err := c.getGroup(kr, groupKey, groupValue)
if err != nil {
return err
}
return c.deleteGroup(kr, group)
}
type CardGroup struct {
Card
kr *crypto.KeyRing
group string
}
// Get returns the values in the group with the given key.
func (g CardGroup) Get(key string) ([]string, error) {
dec, err := g.decode(g.kr)
if err != nil {
return nil, err
}
var fields []*vcard.Field
for _, field := range dec[key] {
if field.Group != g.group {
continue
}
fields = append(fields, field)
}
return xslices.Map(fields, func(field *vcard.Field) string {
return field.Value
}), nil
}
// Set sets the value in the group.
func (g *CardGroup) Set(key, value string, params vcard.Params) error {
dec, err := g.decode(g.kr)
if err != nil {
return err
}
for _, field := range dec[key] {
if field.Group != g.group {
continue
}
field.Value = value
return g.encode(g.kr, dec)
}
dec.Add(key, &vcard.Field{
Value: value,
Group: g.group,
Params: params,
})
return g.encode(g.kr, dec)
}
// Add adds a value to the group.
func (g *CardGroup) Add(key, value string, params vcard.Params) error {
dec, err := g.decode(g.kr)
if err != nil {
return err
}
dec.Add(key, &vcard.Field{
Value: value,
Group: g.group,
Params: params,
})
return g.encode(g.kr, dec)
}
// Remove removes the value in the group with the given key/value.
func (g *CardGroup) Remove(key, value string) error {
dec, err := g.decode(g.kr)
if err != nil {
return err
}
fields, ok := dec[key]
if !ok {
return errors.New("no such key")
}
var rest []*vcard.Field
for _, field := range fields {
if field.Group != g.group {
rest = append(rest, field)
} else if field.Value != value {
rest = append(rest, field)
}
}
if len(rest) > 0 {
dec[key] = rest
} else {
delete(dec, key)
}
return g.encode(g.kr, dec)
}
// RemoveAll removes all values in the group with the given key.
func (g *CardGroup) RemoveAll(key string) error {
dec, err := g.decode(g.kr)
if err != nil {
return err
}
fields, ok := dec[key]
if !ok {
return errors.New("no such key")
}
var rest []*vcard.Field
for _, field := range fields {
if field.Group != g.group {
rest = append(rest, field)
}
}
if len(rest) > 0 {
dec[key] = rest
} else {
delete(dec, key)
}
return g.encode(g.kr, dec)
}
func (c Card) getGroup(kr *crypto.KeyRing, groupKey, groupValue string) (string, error) {
fields, err := c.Get(kr, groupKey)
if err != nil {
return "", err
}
for _, field := range fields {
if field.Value != groupValue {
continue
}
return field.Group, nil
}
return "", errors.New("no such field")
}
func (c *Card) deleteGroup(kr *crypto.KeyRing, group string) error {
dec, err := c.decode(kr)
if err != nil {
return err
}
for key, fields := range dec {
var rest []*vcard.Field
for _, field := range fields {
if field.Group != group {
rest = append(rest, field)
}
}
if len(rest) > 0 {
dec[key] = rest
} else {
delete(dec, key)
}
}
return c.encode(kr, dec)
}
func (c Card) decode(kr *crypto.KeyRing) (vcard.Card, error) {
if c.Type&CardTypeEncrypted != 0 {
enc, err := crypto.NewPGPMessageFromArmored(c.Data)
if err != nil {
return nil, err
}
dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime())
if err != nil {
return nil, err
}
c.Data = dec.GetString()
}
if c.Type&CardTypeSigned != 0 {
sig, err := crypto.NewPGPSignatureFromArmored(c.Signature)
if err != nil {
return nil, err
}
if err := kr.VerifyDetached(crypto.NewPlainMessageFromString(c.Data), sig, crypto.GetUnixTime()); err != nil {
return nil, err
}
}
return vcard.NewDecoder(strings.NewReader(c.Data)).Decode()
}
func (c *Card) encode(kr *crypto.KeyRing, card vcard.Card) error {
buf := new(bytes.Buffer)
if err := vcard.NewEncoder(buf).Encode(card); err != nil {
return err
}
if c.Type&CardTypeSigned != 0 {
sig, err := kr.SignDetached(crypto.NewPlainMessageFromString(buf.String()))
if err != nil {
return err
}
if c.Signature, err = sig.GetArmored(); err != nil {
return err
}
}
if c.Type&CardTypeEncrypted != 0 {
enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(buf.String()), nil)
if err != nil {
return err
}
if c.Data, err = enc.GetArmored(); err != nil {
return err
}
} else {
c.Data = buf.String()
}
return nil
}

171
contact_types.go Normal file
View File

@@ -0,0 +1,171 @@
package proton
import (
"encoding/base64"
"strconv"
"strings"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/emersion/go-vcard"
)
type RecipientType int
const (
RecipientTypeInternal RecipientType = iota + 1
RecipientTypeExternal
)
type ContactSettings struct {
MIMEType *rfc822.MIMEType
Scheme *EncryptionScheme
Sign *bool
Encrypt *bool
Keys []*crypto.Key
}
type Contact struct {
ContactMetadata
ContactCards
}
func (c *Contact) GetSettings(kr *crypto.KeyRing, email string) (ContactSettings, error) {
signedCard, ok := c.Cards.Get(CardTypeSigned)
if !ok {
return ContactSettings{}, nil
}
group, err := signedCard.GetGroup(kr, vcard.FieldEmail, email)
if err != nil {
return ContactSettings{}, nil
}
var settings ContactSettings
scheme, err := group.Get(FieldPMScheme)
if err != nil {
return ContactSettings{}, err
}
if len(scheme) > 0 {
switch scheme[0] {
case "pgp-inline":
settings.Scheme = newPtr(PGPInlineScheme)
case "pgp-mime":
settings.Scheme = newPtr(PGPMIMEScheme)
}
}
mimeType, err := group.Get(FieldPMMIMEType)
if err != nil {
return ContactSettings{}, err
}
if len(mimeType) > 0 {
settings.MIMEType = newPtr(rfc822.MIMEType(mimeType[0]))
}
sign, err := group.Get(FieldPMSign)
if err != nil {
return ContactSettings{}, err
}
if len(sign) > 0 {
sign, err := strconv.ParseBool(sign[0])
if err != nil {
return ContactSettings{}, err
}
settings.Sign = newPtr(sign)
}
encrypt, err := group.Get(FieldPMEncrypt)
if err != nil {
return ContactSettings{}, err
}
if len(encrypt) > 0 {
encrypt, err := strconv.ParseBool(encrypt[0])
if err != nil {
return ContactSettings{}, err
}
settings.Encrypt = newPtr(encrypt)
}
keys, err := group.Get(vcard.FieldKey)
if err != nil {
return ContactSettings{}, err
}
if len(keys) > 0 {
for _, key := range keys {
dec, err := base64.StdEncoding.DecodeString(strings.SplitN(key, ",", 2)[1])
if err != nil {
return ContactSettings{}, err
}
pubKey, err := crypto.NewKey(dec)
if err != nil {
return ContactSettings{}, err
}
settings.Keys = append(settings.Keys, pubKey)
}
}
return settings, nil
}
type ContactMetadata struct {
ID string
Name string
UID string
Size int64
CreateTime int64
ModifyTime int64
ContactEmails []ContactEmail
LabelIDs []string
}
type ContactCards struct {
Cards Cards
}
type ContactEmail struct {
ID string
Name string
Email string
Type []string
ContactID string
LabelIDs []string
}
type CreateContactsReq struct {
Contacts []ContactCards
Overwrite int
Labels int
}
type CreateContactsRes struct {
Index int
Response struct {
Error
Contact Contact
}
}
type UpdateContactReq struct {
Cards Cards
}
type DeleteContactsReq struct {
IDs []string
}
func newPtr[T any](v T) *T {
return &v
}

22
contexts.go Normal file
View File

@@ -0,0 +1,22 @@
package proton
import "context"
type withClientKeyType struct{}
var withClientKey withClientKeyType
// WithClient marks this context as originating from the client with the given ID.
func WithClient(parent context.Context, clientID uint64) context.Context {
return context.WithValue(parent, withClientKey, clientID)
}
// ClientIDFromContext returns true if this context was marked as originating from a client.
func ClientIDFromContext(ctx context.Context) (uint64, bool) {
clientID, ok := ctx.Value(withClientKey).(uint64)
if !ok {
return 0, false
}
return clientID, true
}

311
dialer.go Normal file
View File

@@ -0,0 +1,311 @@
package proton
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
)
func InsecureTransport() *http.Transport {
return &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
}
// NetCtl can be used to control whether a dialer can dial, and whether the resulting
// connection can read or write.
type NetCtl struct {
canDial atomicBool
dialLimit atomicUint64
canRead atomicBool
readLimit atomicUint64
canWrite atomicBool
writeLimit atomicUint64
onDial []func(net.Conn)
onRead []func([]byte)
onWrite []func([]byte)
lock sync.Mutex
}
// NewNetCtl returns a new NetCtl with all fields set to true.
func NewNetCtl() *NetCtl {
return &NetCtl{
canDial: atomicBool{b32(true)},
canRead: atomicBool{b32(true)},
canWrite: atomicBool{b32(true)},
}
}
// SetCanDial sets whether the dialer can dial.
func (c *NetCtl) SetCanDial(canDial bool) {
c.canDial.Store(canDial)
}
// SetDialLimit sets the maximum number of times dialers using this controller can dial.
func (c *NetCtl) SetDialLimit(limit uint64) {
c.dialLimit.Store(limit)
}
// SetCanRead sets whether the connection can read.
func (c *NetCtl) SetCanRead(canRead bool) {
c.canRead.Store(canRead)
}
// SetReadLimit sets the maximum number of bytes that can be read.
func (c *NetCtl) SetReadLimit(limit uint64) {
c.readLimit.Store(limit)
}
// SetCanWrite sets whether the connection can write.
func (c *NetCtl) SetCanWrite(canWrite bool) {
c.canWrite.Store(canWrite)
}
// SetWriteLimit sets the maximum number of bytes that can be written.
func (c *NetCtl) SetWriteLimit(limit uint64) {
c.writeLimit.Store(limit)
}
// OnDial adds a callback that is called with the created connection when a dial is successful.
func (c *NetCtl) OnDial(f func(net.Conn)) {
c.lock.Lock()
defer c.lock.Unlock()
c.onDial = append(c.onDial, f)
}
// OnRead adds a callback that is called with the read bytes when a read is successful.
func (c *NetCtl) OnRead(f func([]byte)) {
c.lock.Lock()
defer c.lock.Unlock()
c.onRead = append(c.onRead, f)
}
// OnWrite adds a callback that is called with the written bytes when a write is successful.
func (c *NetCtl) OnWrite(f func([]byte)) {
c.lock.Lock()
defer c.lock.Unlock()
c.onWrite = append(c.onWrite, f)
}
// Disable is equivalent to disallowing dial, read and write.
func (c *NetCtl) Disable() {
c.SetCanDial(false)
c.SetCanRead(false)
c.SetCanWrite(false)
}
// Enable is equivalent to allowing dial, read and write.
func (c *NetCtl) Enable() {
c.SetCanDial(true)
c.SetCanRead(true)
c.SetCanWrite(true)
}
// Conn is a wrapper around net.Conn that can be used to control whether a connection can read or write.
type Conn struct {
net.Conn
ctl *NetCtl
readLimiter *readLimiter
writeLimiter *writeLimiter
}
// Read reads from the wrapped connection, but only if the controller allows it.
func (c *Conn) Read(b []byte) (int, error) {
if !c.ctl.canRead.Load() {
return 0, errors.New("cannot read")
}
n, err := c.readLimiter.read(c.Conn, b)
if err != nil {
return n, err
}
for _, f := range c.ctl.onRead {
f(b[:n])
}
return n, err
}
// Write writes to the wrapped connection, but only if the controller allows it.
func (c *Conn) Write(b []byte) (int, error) {
if !c.ctl.canWrite.Load() {
return 0, errors.New("cannot write")
}
n, err := c.writeLimiter.write(c.Conn, b)
if err != nil {
return n, err
}
for _, f := range c.ctl.onWrite {
f(b[:n])
}
return n, err
}
// Dialer performs network dialing, but only if the controller allows it.
type Dialer struct {
ctl *NetCtl
netDialer *net.Dialer
tlsDialer *tls.Dialer
tlsConfig *tls.Config
readLimiter *readLimiter
writeLimiter *writeLimiter
dialCount atomicUint64
}
// NewDialer returns a new dialer using the given net controller.
// It optionally uses a provided tls config.
func NewDialer(ctl *NetCtl, tlsConfig *tls.Config) *Dialer {
return &Dialer{
ctl: ctl,
netDialer: &net.Dialer{},
tlsDialer: &tls.Dialer{Config: tlsConfig},
tlsConfig: tlsConfig,
readLimiter: newReadLimiter(ctl),
writeLimiter: newWriteLimiter(ctl),
dialCount: atomicUint64{0},
}
}
// DialContext dials a network connection, but only if the controller allows it.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dialWithDialer(ctx, network, addr, d.netDialer)
}
// DialTLSContext dials a TLS network connection, but only if the controller allows it.
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dialWithDialer(ctx, network, addr, d.tlsDialer)
}
// dialWithDialer dials a network connection using the given dialer, but only if the controller allows it.
func (d *Dialer) dialWithDialer(ctx context.Context, network, addr string, dialer dialer) (net.Conn, error) {
if !d.ctl.canDial.Load() {
return nil, errors.New("cannot dial")
}
if limit := d.ctl.dialLimit.Load(); limit > 0 && d.dialCount.Load() >= limit {
return nil, errors.New("dial limit reached")
} else {
d.dialCount.Add(1)
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
d.ctl.lock.Lock()
defer d.ctl.lock.Unlock()
for _, f := range d.ctl.onDial {
f(conn)
}
return &Conn{
Conn: conn,
ctl: d.ctl,
readLimiter: d.readLimiter,
writeLimiter: d.writeLimiter,
}, nil
}
// GetRoundTripper returns a new http.RoundTripper that uses the dialer.
func (d *Dialer) GetRoundTripper() http.RoundTripper {
return &http.Transport{
DialContext: d.DialContext,
DialTLSContext: d.DialTLSContext,
TLSClientConfig: d.tlsConfig,
}
}
type dialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
type readLimiter struct {
ctl *NetCtl
count atomicUint64
}
// newReadLimiter returns a new io.Reader that reads from r, but only up to limit bytes.
func newReadLimiter(ctl *NetCtl) *readLimiter {
return &readLimiter{
ctl: ctl,
}
}
func (limiter *readLimiter) read(r io.Reader, b []byte) (int, error) {
if limit := limiter.ctl.readLimit.Load(); limit > 0 && limiter.count.Load() >= limit {
return 0, fmt.Errorf("refusing to read: read limit reached")
}
n, err := r.Read(b)
if err != nil {
return n, err
}
if limit := limiter.ctl.readLimit.Load(); limit > 0 {
if new := limiter.count.Add(uint64(n)); new >= limit {
return 0, fmt.Errorf("read failed: read limit reached")
}
}
return n, err
}
type writeLimiter struct {
ctl *NetCtl
count atomicUint64
}
// newWriteLimiter returns a new io.Writer that writes to w, but only up to limit bytes.
func newWriteLimiter(ctl *NetCtl) *writeLimiter {
return &writeLimiter{
ctl: ctl,
}
}
func (limiter *writeLimiter) write(w io.Writer, b []byte) (int, error) {
if limit := limiter.ctl.writeLimit.Load(); limit > 0 && limiter.count.Load() >= limit {
return 0, fmt.Errorf("refusing to write: write limit reached")
}
n, err := w.Write(b)
if err != nil {
return n, err
}
if limit := limiter.ctl.writeLimit.Load(); limit > 0 {
if new := limiter.count.Add(uint64(n)); new >= limit {
return 0, fmt.Errorf("write failed: write limit reached")
}
}
return n, err
}

79
dialer_test.go Normal file
View File

@@ -0,0 +1,79 @@
package proton_test
import (
"bytes"
"crypto/tls"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/ProtonMail/go-proton-api"
)
func TestNetCtl_ReadLimit(t *testing.T) {
// Create a test http server that writes 100 bytes.
// Including the header, this is 217 bytes (100 bytes + 117 bytes).
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write(make([]byte, 100)); err != nil {
t.Fatal(err)
}
}))
defer ts.Close()
// Create a new net controller.
netCtl := proton.NewNetCtl()
// Create a new http client with the dialer.
client := &http.Client{
Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
}
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
netCtl.SetReadLimit(300)
// This should succeed.
if resp, err := client.Get(ts.URL); err != nil {
t.Fatal(err)
} else {
resp.Body.Close()
}
// This should fail.
if _, err := client.Get(ts.URL); err == nil {
t.Fatal("expected error")
}
}
func TestNetCtl_WriteLimit(t *testing.T) {
// Create a test http server that reads the given body.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := io.ReadAll(r.Body); err != nil {
t.Fatal(err)
}
}))
defer ts.Close()
// Create a new net controller.
netCtl := proton.NewNetCtl()
// Create a new http client with the dialer.
client := &http.Client{
Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
}
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
netCtl.SetWriteLimit(300)
// This should succeed.
if resp, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err != nil {
t.Fatal(err)
} else {
resp.Body.Close()
}
// This should fail.
if _, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err == nil {
t.Fatal("expected error")
}
}

204
drive_types.go Normal file
View File

@@ -0,0 +1,204 @@
package proton
import "github.com/ProtonMail/gopenpgp/v2/crypto"
type Volume struct {
ID string // Encrypted volume ID
Name string // The volume name
OwnerUserID string // Encrypted owner user ID
UsedSpace int64 // Space used by files in the volume in bytes
MaxSpace int64 // Space limit for the volume in bytes
State VolumeState // TODO: What is this?
}
type VolumeState int
const (
// TODO: VolumeState constants
)
type Share struct {
ShareID string // Encrypted share ID
Type ShareType // Type of share
State ShareState // TODO: What is this?
PermissionsMask Permissions // Mask restricting member permissions on the share
LinkID string // Encrypted link ID to which the share points (root of share).
LinkType LinkType // TODO: What is this?
VolumeID string // Encrypted volume ID on which the share is mounted
Creator string // Creator address
AddressID string
Flags ShareFlags // The flag bitmap, with the following values
BlockSize int64 // TODO: What is this?
Locked bool // TODO: What is this?
Key string // The private key, encrypted with a passphrase
Passphrase string // The encrypted passphrase
PassphraseSignature string // The signature of the passphrase
}
func (s Share) GetKeyRing(kr *crypto.KeyRing) (*crypto.KeyRing, error) {
encPass, err := crypto.NewPGPMessageFromArmored(s.Passphrase)
if err != nil {
return nil, err
}
decPass, err := kr.Decrypt(encPass, nil, crypto.GetUnixTime())
if err != nil {
return nil, err
}
lockedKey, err := crypto.NewKeyFromArmored(s.Key)
if err != nil {
return nil, err
}
unlockedKey, err := lockedKey.Unlock(decPass.GetBinary())
if err != nil {
return nil, err
}
return crypto.NewKeyRing(unlockedKey)
}
type ShareType int
const (
// TODO: ShareType constants
)
type ShareState int
const (
// TODO: ShareState constants
)
type ShareFlags int
const (
NoFlags ShareFlags = iota
PrimaryShare
)
type Link struct {
LinkID string // Encrypted file/folder ID
ParentLinkID string // Encrypted parent folder ID (LinkID)
Type LinkType
Name string // Encrypted file name
Hash string // HMAC of name encrypted with parent hash key
State LinkState // State of the link
ExpirationTime int64
Size int64
MIMEType string
Attributes Attributes
Permissions Permissions
NodeKey string
NodePassphrase string
NodePassphraseSignature string
SignatureAddress string
CreateTime int64
ModifyTime int64
FileProperties FileProperties
FolderProperties FolderProperties
}
func (l Link) GetKeyRing(kr *crypto.KeyRing) (*crypto.KeyRing, error) {
encPass, err := crypto.NewPGPMessageFromArmored(l.NodePassphrase)
if err != nil {
return nil, err
}
decPass, err := kr.Decrypt(encPass, nil, crypto.GetUnixTime())
if err != nil {
return nil, err
}
lockedKey, err := crypto.NewKeyFromArmored(l.NodeKey)
if err != nil {
return nil, err
}
unlockedKey, err := lockedKey.Unlock(decPass.GetBinary())
if err != nil {
return nil, err
}
return crypto.NewKeyRing(unlockedKey)
}
type FileProperties struct {
ContentKeyPacket string
ActiveRevision Revision
}
type FolderProperties struct{}
type LinkType int
const (
FolderLinkType LinkType = iota + 1
FileLinkType
)
type LinkState int
const (
DraftLinkState LinkState = iota
ActiveLinkState
TrashedLinkState
DeletedLinkState
)
type Revision struct {
ID string // Encrypted Revision ID
CreateTime int64 // Unix timestamp of the revision creation time
Size int64 // Size of the file in bytes
ManifestSignature string // The signature of the root hash
SignatureAddress string // The address used to sign the root hash
State FileRevisionState // State of revision
Blocks []Block
}
type FileRevisionState int
const (
DraftRevisionState FileRevisionState = iota
ActiveRevisionState
ObsoleteRevisionState
)
type Block struct {
Index int
URL string
EncSignature string
SignatureEmail string
}
type LinkEvent struct {
EventID string // Encrypted ID of the Event
CreateTime int64 // Time stamp of the creation time of the Event
EventType LinkEventType // Type of event
}
type LinkEventType int
const (
DeleteLinkEvent LinkEventType = iota
CreateLinkEvent
UpdateContentsLinkEvent
UpdateMetadataLinkEvent
)
type Permissions int
const (
NoPermissions Permissions = 1 << iota
ReadPermission
WritePermission
AdministerMembersPermission
AdminPermission
SuperAdminPermission
)
type Attributes uint32

102
event.go Normal file
View File

@@ -0,0 +1,102 @@
package proton
import (
"context"
"time"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetLatestEventID(ctx context.Context) (string, error) {
var res struct {
Event
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/core/v4/events/latest")
}); err != nil {
return "", err
}
return res.EventID, nil
}
func (c *Client) GetEvent(ctx context.Context, eventID string) (Event, error) {
event, more, err := c.getEvent(ctx, eventID)
if err != nil {
return Event{}, err
}
for more {
var next Event
next, more, err = c.getEvent(ctx, event.EventID)
if err != nil {
return Event{}, err
}
if err := event.merge(next); err != nil {
return Event{}, err
}
}
return event, nil
}
// NewEventStreamer returns a new event stream.
// It polls the API for new events at random intervals between `period` and `period+jitter`.
func (c *Client) NewEventStream(ctx context.Context, period, jitter time.Duration, lastEventID string) <-chan Event {
eventCh := make(chan Event)
go func() {
defer close(eventCh)
ticker := NewTicker(period, jitter)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
// ...
}
event, err := c.GetEvent(ctx, lastEventID)
if err != nil {
continue
}
if event.EventID == lastEventID {
continue
}
select {
case <-ctx.Done():
return
case eventCh <- event:
lastEventID = event.EventID
}
}
}()
return eventCh
}
func (c *Client) getEvent(ctx context.Context, eventID string) (Event, bool, error) {
var res struct {
Event
More Bool
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/core/v4/events/" + eventID)
}); err != nil {
return Event{}, false, err
}
return res.Event, bool(res.More), nil
}

70
event_test.go Normal file
View File

@@ -0,0 +1,70 @@
package proton_test
import (
"context"
"testing"
"time"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/stretchr/testify/require"
)
func TestEventStreamer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := server.New()
defer s.Close()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.InsecureTransport()),
)
_, _, err := s.CreateUser("username", "email@pm.me", []byte("password"))
require.NoError(t, err)
c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))
require.NoError(t, err)
createTestMessages(t, c, "password", 10)
latestEventID, err := c.GetLatestEventID(ctx)
require.NoError(t, err)
eventCh := make(chan proton.Event)
go func() {
for event := range c.NewEventStream(ctx, time.Second, 0, latestEventID) {
eventCh <- event
}
}()
// Perform some action to generate an event.
metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{})
require.NoError(t, err)
require.NoError(t, c.LabelMessages(ctx, []string{metadata[0].ID}, proton.TrashLabel))
// Wait for the first event.
<-eventCh
// Close the client; this should stop the client's event streamer.
c.Close()
// Create a new client and perform some actions with it to generate more events.
cc, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))
require.NoError(t, err)
defer cc.Close()
require.NoError(t, cc.LabelMessages(ctx, []string{metadata[1].ID}, proton.TrashLabel))
// We should not receive any more events from the original client.
select {
case <-eventCh:
require.Fail(t, "received unexpected event")
default:
// ...
}
}

140
event_types.go Normal file
View File

@@ -0,0 +1,140 @@
package proton
import (
"fmt"
"strings"
"github.com/bradenaw/juniper/xslices"
)
type Event struct {
EventID string
Refresh RefreshFlag
User *User
MailSettings *MailSettings
Messages []MessageEvent
Labels []LabelEvent
Addresses []AddressEvent
}
func (event Event) String() string {
var parts []string
if event.Refresh != 0 {
parts = append(parts, fmt.Sprintf("refresh: %v", event.Refresh))
}
if event.User != nil {
parts = append(parts, "user: [modified]")
}
if event.MailSettings != nil {
parts = append(parts, "mail-settings: [modified]")
}
if len(event.Messages) > 0 {
parts = append(parts, fmt.Sprintf(
"messages: created=%d, updated=%d, deleted=%d",
xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventCreate }),
xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }),
xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventDelete }),
))
}
if len(event.Labels) > 0 {
parts = append(parts, fmt.Sprintf(
"labels: created=%d, updated=%d, deleted=%d",
xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventCreate }),
xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }),
xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventDelete }),
))
}
if len(event.Addresses) > 0 {
parts = append(parts, fmt.Sprintf(
"addresses: created=%d, updated=%d, deleted=%d",
xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventCreate }),
xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }),
xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventDelete }),
))
}
return fmt.Sprintf("Event %s: %s", event.EventID, strings.Join(parts, ", "))
}
// merge combines this event with the other event (assumed to be newer!).
// TODO: Intelligent merging: if there are multiple EventUpdate(Flags) events, can we just take the latest one?
func (event *Event) merge(other Event) error {
event.EventID = other.EventID
if other.User != nil {
event.User = other.User
}
if other.MailSettings != nil {
event.MailSettings = other.MailSettings
}
// For now, label events are simply appended.
event.Labels = append(event.Labels, other.Labels...)
// For now, message events are simply appended.
event.Messages = append(event.Messages, other.Messages...)
// For now, address events are simply appended.
event.Addresses = append(event.Addresses, other.Addresses...)
return nil
}
type RefreshFlag uint8
const (
RefreshMail RefreshFlag = 1 << iota // 1<<0 = 1
_ // 1<<1 = 2
_ // 1<<2 = 4
_ // 1<<3 = 8
_ // 1<<4 = 16
_ // 1<<5 = 32
_ // 1<<6 = 64
_ // 1<<7 = 128
RefreshAll RefreshFlag = 1<<iota - 1 // 1<<8 - 1 = 255
)
type EventAction int
const (
EventDelete EventAction = iota
EventCreate
EventUpdate
EventUpdateFlags
)
type EventItem struct {
ID string
Action EventAction
}
type MessageEvent struct {
EventItem
Message MessageMetadata
}
type LabelEvent struct {
EventItem
Label Label
}
type AddressEvent struct {
EventItem
Address Address
}

122
example_test.go Normal file
View File

@@ -0,0 +1,122 @@
package proton_test
import (
"context"
"fmt"
"time"
"github.com/ProtonMail/go-proton-api"
)
func ExampleManager_NewClient() {
// Create a new manager.
m := proton.New()
// If auth information is already known, it can be used to create a client straight away.
c := m.NewClient("...uid...", "...acc...", "...ref...", time.Now().Add(time.Hour))
defer c.Close()
// All API operations must be run within a context.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Do something with the client.
if _, err := c.GetUser(ctx); err != nil {
panic(err)
}
}
func ExampleManager_NewClientWithRefresh() {
// Create a new manager.
m := proton.New()
// All API operations must be run within a context.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// If UID/RefreshToken is already known, it can be used to create a new client straight away.
c, _, err := m.NewClientWithRefresh(ctx, "...uid...", "...ref...")
if err != nil {
panic(err)
}
defer c.Close()
// Do something with the client.
if _, err := c.GetUser(ctx); err != nil {
panic(err)
}
}
func ExampleManager_NewClientWithLogin() {
// Create a new manager.
m := proton.New()
// All API operations must be run within a context.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Clients are created via username/password if auth information isn't already known.
c, auth, err := m.NewClientWithLogin(ctx, "...user...", []byte("...pass..."))
if err != nil {
panic(err)
}
defer c.Close()
// If 2FA is necessary, an additional request is required.
if auth.TwoFA.Enabled == proton.TOTPEnabled {
if err := c.Auth2FA(ctx, proton.Auth2FAReq{TwoFactorCode: "...TOTP..."}); err != nil {
panic(err)
}
}
// Do something with the client.
if _, err := c.GetUser(ctx); err != nil {
panic(err)
}
}
func ExampleClient_AddAuthHandler() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a new manager.
m := proton.New()
// Create a new client.
c := m.NewClient("...uid...", "...acc...", "...ref...", time.Now().Add(time.Hour))
defer c.Close()
// Register an auth handler with the client.
// This could be used for example to save the auth to keychain.
c.AddAuthHandler(func(auth proton.Auth) {
// Do something with auth.
})
if _, err := c.GetUser(ctx); err != nil {
panic(err)
}
}
func ExampleClient_NewEventStream() {
m := proton.New()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c, _, err := m.NewClientWithLogin(ctx, "...user...", []byte("...pass..."))
if err != nil {
panic(err)
}
defer c.Close()
// Get the latest event ID.
fromEventID, err := c.GetLatestEventID(context.Background())
if err != nil {
panic(err)
}
// Create a new event streamer.
for event := range c.NewEventStream(ctx, 20*time.Second, 20*time.Second, fromEventID) {
fmt.Println(event.EventID)
}
}

78
future.go Normal file
View File

@@ -0,0 +1,78 @@
package proton
type Future[T any] struct {
resCh chan res[T]
}
type res[T any] struct {
val T
err error
}
func NewFuture[T any](fn func() (T, error)) *Future[T] {
resCh := make(chan res[T])
go func() {
val, err := fn()
resCh <- res[T]{val: val, err: err}
}()
return &Future[T]{resCh: resCh}
}
func (job *Future[T]) Then(fn func(T, error)) {
go func() {
res := <-job.resCh
fn(res.val, res.err)
}()
}
func (job *Future[T]) Get() (T, error) {
res := <-job.resCh
return res.val, res.err
}
type Group[T any] struct {
futures []*Future[T]
}
func NewGroup[T any]() *Group[T] {
return &Group[T]{}
}
func (group *Group[T]) Add(fn func() (T, error)) {
group.futures = append(group.futures, NewFuture(fn))
}
func (group *Group[T]) Result() ([]T, error) {
var out []T
for _, future := range group.futures {
res, err := future.Get()
if err != nil {
return nil, err
}
out = append(out, res)
}
return out, nil
}
func (group *Group[T]) ForEach(fn func(T) error) error {
for _, future := range group.futures {
res, err := future.Get()
if err != nil {
return err
}
if err := fn(res); err != nil {
return err
}
}
return nil
}

48
future_test.go Normal file
View File

@@ -0,0 +1,48 @@
package proton
import (
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestFuture(t *testing.T) {
resCh := make(chan int)
NewFuture(func() (int, error) {
return 42, nil
}).Then(func(res int, err error) {
resCh <- res
})
require.Equal(t, 42, <-resCh)
}
func TestGroup(t *testing.T) {
group := NewGroup[int]()
for i := 0; i < 10; i++ {
i := i
group.Add(func() (int, error) {
// Sleep a random amount of time so that results are returned in a random order.
time.Sleep(time.Duration(rand.Int()%10) * time.Millisecond) //nolint:gosec
// Return the job index [0, 10].
return i, nil
})
}
resCh := make(chan int)
go func() {
require.Equal(t, group.ForEach(func(res int) error { resCh <- res; return nil }), nil)
}()
// Results should be returned in the original order.
for i := 0; i < 10; i++ {
require.Equal(t, i, <-resCh)
}
}

60
go.mod Normal file
View File

@@ -0,0 +1,60 @@
module github.com/ProtonMail/go-proton-api
go 1.18
require (
github.com/Masterminds/semver/v3 v3.1.1
github.com/ProtonMail/gluon v0.13.1-0.20221025093924-86bbf0261eb8
github.com/ProtonMail/go-crypto v0.0.0-20220824120805-4b6e5c587895
github.com/ProtonMail/go-srp v0.0.5
github.com/ProtonMail/gopenpgp/v2 v2.4.10
github.com/bradenaw/juniper v0.8.0
github.com/emersion/go-message v0.16.0
github.com/emersion/go-vcard v0.0.0-20220507122617-d4056df0ec4a
github.com/gin-gonic/gin v1.8.1
github.com/go-resty/resty/v2 v2.7.0
github.com/google/go-cmp v0.5.8
github.com/google/uuid v1.3.0
github.com/stretchr/testify v1.8.0
github.com/urfave/cli/v2 v2.20.3
go.uber.org/goleak v1.1.12
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91
google.golang.org/grpc v1.50.1
google.golang.org/protobuf v1.28.0
)
require (
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect
github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f // indirect
github.com/cloudflare/circl v1.2.0 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-playground/validator/v10 v10.10.0 // indirect
github.com/goccy/go-json v0.9.7 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/ugorji/go/codec v1.2.7 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b // indirect
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.13-0.20220804200503-81c7dc4e4efa // indirect
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

258
go.sum Normal file
View File

@@ -0,0 +1,258 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
github.com/ProtonMail/bcrypt v0.0.0-20210511135022-227b4adcab57/go.mod h1:HecWFHognK8GfRDGnFQbW/LiV7A3MX3gZVs45vk5h8I=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf h1:yc9daCCYUefEs69zUkSzubzjBbL+cmOXgnmt9Fyd9ug=
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf/go.mod h1:o0ESU9p83twszAU8LBeJKFAAMX14tISa0yk4Oo5TOqo=
github.com/ProtonMail/gluon v0.13.1-0.20221025093924-86bbf0261eb8 h1:LKyiQdEsAxAocSYUWxSfwlxBwmzJYvO/9td/eAX3oFU=
github.com/ProtonMail/gluon v0.13.1-0.20221025093924-86bbf0261eb8/go.mod h1:XW/gcr4jErc5bX5yMqkUq3U+AucC2QZHJ5L231k3Nw4=
github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/9nQmJSSwwds7ejkxaJwO37dru3geImFUdJlaLzQo=
github.com/ProtonMail/go-crypto v0.0.0-20220822140716-1678d6eb0cbe/go.mod h1:UBYPn8k0D56RtnR8RFQMjmh4KrZzWJ5o7Z9SYjossQ8=
github.com/ProtonMail/go-crypto v0.0.0-20220824120805-4b6e5c587895 h1:NsReiLpErIPzRrnogAXYwSoU7txA977LjDGrbkewJbg=
github.com/ProtonMail/go-crypto v0.0.0-20220824120805-4b6e5c587895/go.mod h1:UBYPn8k0D56RtnR8RFQMjmh4KrZzWJ5o7Z9SYjossQ8=
github.com/ProtonMail/go-mime v0.0.0-20220302105931-303f85f7fe0f/go.mod h1:NYt+V3/4rEeDuaev/zw1zCq8uqVEuPHzDPo3OZrlGJ4=
github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f h1:4IWzKjHzZxdrW9k4zl/qCwenOVHDbVDADPPHFLjs0Oc=
github.com/ProtonMail/go-mime v0.0.0-20220429130430-2192574d760f/go.mod h1:qRZgbeASl2a9OwmsV85aWwRqic0NHPh+9ewGAzb4cgM=
github.com/ProtonMail/go-srp v0.0.5 h1:xhUioxZgDbCnpo9JehyFhwwsn9JLWkUGfB0oiKXgiGg=
github.com/ProtonMail/go-srp v0.0.5/go.mod h1:06iYHtLXW8vjLtccWj++x3MKy65sIT8yZd7nrJF49rs=
github.com/ProtonMail/gopenpgp/v2 v2.4.10 h1:EYgkxzwmQvsa6kxxkgP1AwzkFqKHscF2UINxaSn6rdI=
github.com/ProtonMail/gopenpgp/v2 v2.4.10/go.mod h1:CTRA7/toc/4DxDy5Du4hPDnIZnJvXSeQ8LsRTOUJoyc=
github.com/bradenaw/juniper v0.8.0 h1:sdanLNdJbLjcLj993VYIwUHlUVkLzvgiD/x9O7cvvxk=
github.com/bradenaw/juniper v0.8.0/go.mod h1:Z2B7aJlQ7xbfWsnMLROj5t/5FQ94/MkIdKC30J4WvzI=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/bwesterb/go-ristretto v1.2.1/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/cloudflare/circl v1.2.0 h1:NheeISPSUcYftKlfrLuOo4T62FkmD4t4jviLfFFYaec=
github.com/cloudflare/circl v1.2.0/go.mod h1:Ch2UgYr6ti2KTtlejELlROl0YIYj7SLjAC8M+INXlMk=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cronokirby/saferith v0.33.0 h1:TgoQlfsD4LIwx71+ChfRcIpjkw+RPOapDEVxa+LhwLo=
github.com/cronokirby/saferith v0.33.0/go.mod h1:QKJhjoqUtBsXCAVEjw38mFqoi7DebT7kthcD7UzbnoA=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emersion/go-message v0.16.0 h1:uZLz8ClLv3V5fSFF/fFdW9jXjrZkXIpE1Fn8fKx7pO4=
github.com/emersion/go-message v0.16.0/go.mod h1:pDJDgf/xeUIF+eicT6B/hPX/ZbEorKkUMPOxrPVG2eQ=
github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594 h1:IbFBtwoTQyw0fIM5xv1HF+Y+3ZijDR839WMulgxCcUY=
github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594/go.mod h1:aqO8z8wPrjkscevZJFVE1wXJrLpC5LtJG7fqLOsPb2U=
github.com/emersion/go-vcard v0.0.0-20220507122617-d4056df0ec4a h1:cltZpe6s0SJtqK5c/5y2VrIYi8BAtDM6qjmiGYqfTik=
github.com/emersion/go-vcard v0.0.0-20220507122617-d4056df0ec4a/go.mod h1:HMJKR5wlh/ziNp+sHEDV2ltblO4JD2+IdDOWtGcQBTM=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8=
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho=
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
github.com/go-playground/validator/v10 v10.10.0 h1:I7mrTYv78z8k8VXa/qJlOlEXn/nBh+BF8dHX5nt/dr0=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.1 h1:8e3L2cCQzLFi2CR4g7vGFuFxX7Jl1kKX8gW+iV0GUKU=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
github.com/urfave/cli/v2 v2.20.3 h1:lOgGidH/N5loaigd9HjFsOIhXSTrzl7tBpHswZ428w4=
github.com/urfave/cli/v2 v2.20.3/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA=
go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 h1:Y/gsMcFOcR+6S6f3YeMKl5g+dZMEWqcz5Czj/GWYbkM=
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20200801112145-973feb4309de/go.mod h1:skQtrUTUwhdJvXM/2KKJzY8pDgNr9I/FOMqDVRPBUS4=
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.1.1-0.20191209134235-331c550502dd/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b h1:ZmngSVLe/wycRns9MKikG9OWIEjGcGAkacif7oYQaUY=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220315194320-039c03cc5b86/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.13-0.20220804200503-81c7dc4e4efa h1:uKcci2q7Qtp6nMTC/AAvfNUAldFtJuHWV9/5QWiypts=
golang.org/x/tools v0.1.13-0.20220804200503-81c7dc4e4efa/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.50.1 h1:DS/BukOZWp8s6p4Dt/tOaJaTQyPyOoCcrjroHuCeLzY=
google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

47
header_types.go Normal file
View File

@@ -0,0 +1,47 @@
package proton
import (
"encoding/json"
"errors"
)
var ErrBadHeader = errors.New("bad header")
type Headers map[string][]string
func (h *Headers) UnmarshalJSON(b []byte) error {
type rawHeaders map[string]any
raw := make(rawHeaders)
if err := json.Unmarshal(b, &raw); err != nil {
return err
}
header := make(Headers)
for key, val := range raw {
switch val := val.(type) {
case string:
header[key] = []string{val}
case []any:
for _, val := range val {
switch val := val.(type) {
case string:
header[key] = append(header[key], val)
default:
return ErrBadHeader
}
}
default:
return ErrBadHeader
}
}
*h = header
return nil
}

54
helper_test.go Normal file
View File

@@ -0,0 +1,54 @@
package proton_test
import (
"context"
"fmt"
"runtime"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/stream"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func createTestMessages(t *testing.T, c *proton.Client, pass string, count int) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
user, err := c.GetUser(ctx)
require.NoError(t, err)
addr, err := c.GetAddresses(ctx)
require.NoError(t, err)
salt, err := c.GetSalts(ctx)
require.NoError(t, err)
keyPass, err := salt.SaltForKey([]byte(pass), user.Keys.Primary().ID)
require.NoError(t, err)
_, addrKRs, err := proton.Unlock(user, addr, keyPass)
require.NoError(t, err)
req := iterator.Collect(iterator.Map(iterator.Counter(count), func(i int) proton.ImportReq {
return proton.ImportReq{
Metadata: proton.ImportMetadata{
AddressID: addr[0].ID,
Flags: proton.MessageFlagReceived,
Unread: true,
},
Message: []byte(fmt.Sprintf("From: sender@pm.me\r\nReceiver: recipient@pm.me\r\nSubject: %v\r\n\r\nHello World!", uuid.New())),
}
}))
res, err := stream.Collect(ctx, c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), req...))
require.NoError(t, err)
for _, res := range res {
require.Equal(t, proton.SuccessCode, res.Code)
}
}

41
job.go Normal file
View File

@@ -0,0 +1,41 @@
package proton
import "context"
type job[In, Out any] struct {
ctx context.Context
req In
res chan Out
err chan error
done chan struct{}
}
func newJob[In, Out any](ctx context.Context, req In) *job[In, Out] {
return &job[In, Out]{
ctx: ctx,
req: req,
res: make(chan Out),
err: make(chan error),
done: make(chan struct{}),
}
}
func (job *job[In, Out]) result() (Out, error) {
return <-job.res, <-job.err
}
func (job *job[In, Out]) postSuccess(res Out) {
close(job.err)
job.res <- res
}
func (job *job[In, Out]) postFailure(err error) {
close(job.res)
job.err <- err
}
func (job *job[In, Out]) waitDone() {
<-job.done
}

318
keyring.go Normal file
View File

@@ -0,0 +1,318 @@
package proton
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"io"
"strings"
"github.com/ProtonMail/go-crypto/openpgp"
"github.com/ProtonMail/go-crypto/openpgp/armor"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
)
func ExtractSignatures(kr *crypto.KeyRing, arm string) ([]Signature, error) {
entities := xslices.Map(kr.GetKeys(), func(key *crypto.Key) *openpgp.Entity {
return key.GetEntity()
})
p, err := armor.Decode(strings.NewReader(arm))
if err != nil {
return nil, err
}
msg, err := openpgp.ReadMessage(p.Body, openpgp.EntityList(entities), nil, nil)
if err != nil {
return nil, err
}
if _, err := io.ReadAll(msg.UnverifiedBody); err != nil {
return nil, err
}
if !msg.IsSigned {
return nil, nil
}
var signatures []Signature
for _, signature := range msg.UnverifiedSignatures {
buf := new(bytes.Buffer)
if err := signature.Serialize(buf); err != nil {
return nil, err
}
signatures = append(signatures, Signature{
Hash: signature.Hash.String(),
Data: crypto.NewPGPSignature(buf.Bytes()),
})
}
return signatures, nil
}
type Key struct {
ID string
PrivateKey []byte
Token string
Signature string
Primary Bool
Active Bool
Flags KeyState
}
func (key *Key) UnmarshalJSON(data []byte) error {
type Alias Key
aux := &struct {
PrivateKey string
*Alias
}{
Alias: (*Alias)(key),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
privKey, err := crypto.NewKeyFromArmored(aux.PrivateKey)
if err != nil {
return err
}
raw, err := privKey.Serialize()
if err != nil {
return err
}
key.PrivateKey = raw
return nil
}
func (key Key) MarshalJSON() ([]byte, error) {
privKey, err := crypto.NewKey(key.PrivateKey)
if err != nil {
return nil, err
}
arm, err := privKey.Armor()
if err != nil {
return nil, err
}
type Alias Key
aux := &struct {
PrivateKey string
*Alias
}{
PrivateKey: arm,
Alias: (*Alias)(&key),
}
return json.Marshal(aux)
}
type Keys []Key
func (keys Keys) Primary() Key {
for _, key := range keys {
if key.Primary {
return key
}
}
panic("no primary key available")
}
func (keys Keys) ByID(keyID string) Key {
for _, key := range keys {
if key.ID == keyID {
return key
}
}
panic("no primary key available")
}
func (keys Keys) Unlock(passphrase []byte, userKR *crypto.KeyRing) (*crypto.KeyRing, error) {
kr, err := crypto.NewKeyRing(nil)
if err != nil {
return nil, err
}
for _, key := range xslices.Filter(keys, func(key Key) bool { return bool(key.Active) }) {
unlocked, err := key.Unlock(passphrase, userKR)
if err != nil {
continue
}
if err := kr.AddKey(unlocked); err != nil {
return nil, err
}
}
return kr, nil
}
func (keys Keys) TryUnlock(passphrase []byte, userKR *crypto.KeyRing) *crypto.KeyRing {
kr, err := keys.Unlock(passphrase, userKR)
if err != nil {
return nil
}
return kr
}
type PublicKey struct {
Flags KeyState
PublicKey string
}
type PublicKeys []PublicKey
func (keys PublicKeys) GetKeyRing() (*crypto.KeyRing, error) {
kr, err := crypto.NewKeyRing(nil)
if err != nil {
return nil, err
}
for _, key := range keys {
pubKey, err := crypto.NewKeyFromArmored(key.PublicKey)
if err != nil {
return nil, err
}
if err := kr.AddKey(pubKey); err != nil {
return nil, err
}
}
return kr, nil
}
type KeyList struct {
Data string
Signature string
}
func NewKeyList(signer *crypto.KeyRing, entries []KeyListEntry) (KeyList, error) {
data, err := json.Marshal(entries)
if err != nil {
return KeyList{}, err
}
sig, err := signer.SignDetached(crypto.NewPlainMessage(data))
if err != nil {
return KeyList{}, err
}
arm, err := sig.GetArmored()
if err != nil {
return KeyList{}, err
}
return KeyList{
Data: string(data),
Signature: arm,
}, nil
}
type KeyListEntry struct {
Fingerprint string
SHA256Fingerprints []string
Flags KeyState
Primary Bool
}
type KeyState int
const (
KeyStateTrusted KeyState = 1 << iota // 2^0 = 1 means the key is not compromised (i.e. if we can trust signatures coming from it)
KeyStateActive // 2^1 = 2 means the key is still in use (i.e. not obsolete, we can encrypt messages to it)
)
func (key Key) Unlock(passphrase []byte, userKR *crypto.KeyRing) (*crypto.Key, error) {
var secret []byte
if key.Token == "" || key.Signature == "" {
secret = passphrase
} else {
var err error
if secret, err = key.getPassphraseFromToken(userKR); err != nil {
return nil, err
}
}
return key.unlock(secret)
}
func (key Key) getPassphraseFromToken(kr *crypto.KeyRing) ([]byte, error) {
if kr == nil {
return nil, errors.New("no user key was provided")
}
msg, err := crypto.NewPGPMessageFromArmored(key.Token)
if err != nil {
return nil, err
}
sig, err := crypto.NewPGPSignatureFromArmored(key.Signature)
if err != nil {
return nil, err
}
token, err := kr.Decrypt(msg, nil, 0)
if err != nil {
return nil, err
}
if err = kr.VerifyDetached(token, sig, 0); err != nil {
return nil, err
}
return token.GetBinary(), nil
}
func (key Key) unlock(passphrase []byte) (*crypto.Key, error) {
lk, err := crypto.NewKey(key.PrivateKey)
if err != nil {
return nil, err
}
defer lk.ClearPrivateParams()
uk, err := lk.Unlock(passphrase)
if err != nil {
return nil, err
}
ok, err := uk.Check()
if err != nil {
return nil, err
} else if !ok {
return nil, errors.New("private and public keys do not match")
}
return uk, nil
}
func DecodeKeyPacket(packet string) []byte {
if packet == "" {
return nil
}
raw, err := base64.StdEncoding.DecodeString(packet)
if err != nil {
panic(err)
}
return raw
}

62
keys.go Normal file
View File

@@ -0,0 +1,62 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetPublicKeys(ctx context.Context, address string) (PublicKeys, RecipientType, error) {
var res struct {
Keys []PublicKey
RecipientType RecipientType
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).SetQueryParam("Email", address).Get("/core/v4/keys")
}); err != nil {
return nil, RecipientTypeExternal, err
}
return res.Keys, res.RecipientType, nil
}
func (c *Client) CreateAddressKey(ctx context.Context, req CreateAddressKeyReq) (Key, error) {
var res struct {
Key Key
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/core/v4/keys/address")
}); err != nil {
return Key{}, err
}
return res.Key, nil
}
func (c *Client) CreateLegacyAddressKey(ctx context.Context, req CreateAddressKeyReq) (Key, error) {
var res struct {
Key Key
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/core/v4/keys")
}); err != nil {
return Key{}, err
}
return res.Key, nil
}
func (c *Client) MakeAddressKeyPrimary(ctx context.Context, keyID string, keyList KeyList) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(struct{ SignedKeyList KeyList }{SignedKeyList: keyList}).Put("/core/v4/keys/" + keyID + "/primary")
})
}
func (c *Client) DeleteAddressKey(ctx context.Context, keyID string, keyList KeyList) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(struct{ SignedKeyList KeyList }{SignedKeyList: keyList}).Put("/core/v4/keys/" + keyID + "/delete")
})
}

16
keys_types.go Normal file
View File

@@ -0,0 +1,16 @@
package proton
type CreateAddressKeyReq struct {
AddressID string
PrivateKey string
Primary Bool
SignedKeyList KeyList
// The following are only used in "migrated accounts"
Token string `json:",omitempty"`
Signature string `json:",omitempty"`
}
type MakeAddressKeyPrimaryReq struct {
SignedKeyList KeyList
}

82
label.go Normal file
View File

@@ -0,0 +1,82 @@
package proton
import (
"context"
"errors"
"strconv"
"github.com/go-resty/resty/v2"
)
var ErrNoSuchLabel = errors.New("no such label")
func (c *Client) GetLabel(ctx context.Context, labelID string, labelTypes ...LabelType) (Label, error) {
labels, err := c.GetLabels(ctx, labelTypes...)
if err != nil {
return Label{}, err
}
for _, label := range labels {
if label.ID == labelID {
return label, nil
}
}
return Label{}, ErrNoSuchLabel
}
func (c *Client) GetLabels(ctx context.Context, labelTypes ...LabelType) ([]Label, error) {
var labels []Label
for _, labelType := range labelTypes {
labelType := labelType
var res struct {
Labels []Label
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParam("Type", strconv.Itoa(int(labelType))).SetResult(&res).Get("/core/v4/labels")
}); err != nil {
return nil, err
}
labels = append(labels, res.Labels...)
}
return labels, nil
}
func (c *Client) CreateLabel(ctx context.Context, req CreateLabelReq) (Label, error) {
var res struct {
Label Label
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/core/v4/labels")
}); err != nil {
return Label{}, err
}
return res.Label, nil
}
func (c *Client) DeleteLabel(ctx context.Context, labelID string) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.Delete("/core/v4/labels/" + labelID)
})
}
func (c *Client) UpdateLabel(ctx context.Context, labelID string, req UpdateLabelReq) (Label, error) {
var res struct {
Label Label
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/core/v4/labels/" + labelID)
}); err != nil {
return Label{}, err
}
return res.Label, nil
}

87
label_types.go Normal file
View File

@@ -0,0 +1,87 @@
package proton
import (
"encoding/json"
"strings"
)
const (
InboxLabel = "0"
AllDraftsLabel = "1"
AllSentLabel = "2"
TrashLabel = "3"
SpamLabel = "4"
AllMailLabel = "5"
ArchiveLabel = "6"
SentLabel = "7"
DraftsLabel = "8"
OutboxLabel = "9"
StarredLabel = "10"
)
type Label struct {
ID string
Name string
Path []string
Color string
Type LabelType
}
func (label *Label) UnmarshalJSON(data []byte) error {
type Alias Label
aux := &struct {
Path string
*Alias
}{
Alias: (*Alias)(label),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
label.Path = strings.Split(aux.Path, "/")
return nil
}
func (label Label) MarshalJSON() ([]byte, error) {
type Alias Label
aux := &struct {
Path string
*Alias
}{
Path: strings.Join(label.Path, "/"),
Alias: (*Alias)(&label),
}
return json.Marshal(aux)
}
type CreateLabelReq struct {
Name string
Color string
Type LabelType
ParentID string `json:",omitempty"`
}
type UpdateLabelReq struct {
Name string
Color string
ParentID string `json:",omitempty"`
}
type LabelType int
const (
LabelTypeLabel LabelType = iota + 1
LabelTypeContactGroup
LabelTypeFolder
LabelTypeSystem
)

108
link.go Normal file
View File

@@ -0,0 +1,108 @@
package proton
import (
"context"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetLink(ctx context.Context, shareID, linkID string) (Link, error) {
var res struct {
Link Link
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/drive/shares/" + shareID + "/links/" + linkID)
}); err != nil {
return Link{}, err
}
return res.Link, nil
}
func (c *Client) ListChildren(ctx context.Context, shareID, linkID string) ([]Link, error) {
var res struct {
Links []Link
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/drive/shares/" + shareID + "/folders/" + linkID + "/children")
}); err != nil {
return nil, err
}
return res.Links, nil
}
func (c *Client) ListRevisions(ctx context.Context, shareID, linkID string) ([]Revision, error) {
var res struct {
Revisions []Revision
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/drive/shares/" + shareID + "/files/" + linkID + "/revisions")
}); err != nil {
return nil, err
}
return res.Revisions, nil
}
func (c *Client) GetRevision(ctx context.Context, shareID, linkID, revisionID string) (Revision, error) {
var res struct {
Revision Revision
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/drive/shares/" + shareID + "/files/" + linkID + "/revisions/" + revisionID)
}); err != nil {
return Revision{}, err
}
return res.Revision, nil
}
func (c *Client) VisitLink(ctx context.Context, shareID string, link Link, kr *crypto.KeyRing, fn LinkWalkFunc) error {
return c.visitLink(ctx, shareID, link, kr, fn, []string{})
}
func (c *Client) visitLink(ctx context.Context, shareID string, link Link, kr *crypto.KeyRing, fn LinkWalkFunc, path []string) error {
enc, err := crypto.NewPGPMessageFromArmored(link.Name)
if err != nil {
return err
}
dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime())
if err != nil {
return err
}
path = append(path, dec.GetString())
childKR, err := link.GetKeyRing(kr)
if err != nil {
return err
}
if err := fn(path, link, childKR); err != nil {
return err
}
if link.Type != FolderLinkType {
return nil
}
children, err := c.ListChildren(ctx, shareID, link.LinkID)
if err != nil {
return err
}
for _, child := range children {
if err := c.visitLink(ctx, shareID, child, childKR, fn, path); err != nil {
return err
}
}
return nil
}

5
link_types.go Normal file
View File

@@ -0,0 +1,5 @@
package proton
import "github.com/ProtonMail/gopenpgp/v2/crypto"
type LinkWalkFunc func([]string, Link, *crypto.KeyRing) error

105
mail_settings.go Normal file
View File

@@ -0,0 +1,105 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetMailSettings(ctx context.Context) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/mail/v4/settings")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}
func (c *Client) SetDisplayName(ctx context.Context, req SetDisplayNameReq) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/mail/v4/settings/display")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}
func (c *Client) SetSignature(ctx context.Context, req SetSignatureReq) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/mail/v4/settings/signature")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}
func (c *Client) SetDraftMIMEType(ctx context.Context, req SetDraftMIMETypeReq) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/mail/v4/settings/drafttype")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}
func (c *Client) SetAttachPublicKey(ctx context.Context, req SetAttachPublicKeyReq) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/mail/v4/settings/attachpublic")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}
func (c *Client) SetSignExternalMessages(ctx context.Context, req SetSignExternalMessagesReq) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/mail/v4/settings/sign")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}
func (c *Client) SetDefaultPGPScheme(ctx context.Context, req SetDefaultPGPSchemeReq) (MailSettings, error) {
var res struct {
MailSettings MailSettings
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Put("/mail/v4/settings/pgpscheme")
}); err != nil {
return MailSettings{}, err
}
return res.MailSettings, nil
}

50
mail_settings_types.go Normal file
View File

@@ -0,0 +1,50 @@
package proton
import "github.com/ProtonMail/gluon/rfc822"
type MailSettings struct {
DisplayName string
Signature string
DraftMIMEType rfc822.MIMEType
AttachPublicKey AttachPublicKey
Sign SignExternalMessages
PGPScheme EncryptionScheme
}
type AttachPublicKey int
const (
AttachPublicKeyDisabled AttachPublicKey = iota
AttachPublicKeyEnabled
)
type SignExternalMessages int
const (
SignExternalMessagesDisabled SignExternalMessages = iota
SignExternalMessagesEnabled
)
type SetDisplayNameReq struct {
DisplayName string
}
type SetSignatureReq struct {
Signature string
}
type SetDraftMIMETypeReq struct {
MIMEType rfc822.MIMEType
}
type SetAttachPublicKeyReq struct {
AttachPublicKey AttachPublicKey
}
type SetSignExternalMessagesReq struct {
Sign SignExternalMessages
}
type SetDefaultPGPSchemeReq struct {
PGPScheme EncryptionScheme
}

11
main_test.go Normal file
View File

@@ -0,0 +1,11 @@
package proton
import (
"testing"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, goleak.IgnoreCurrent())
}

125
manager.go Normal file
View File

@@ -0,0 +1,125 @@
package proton
import (
"context"
"errors"
"sync"
"github.com/go-resty/resty/v2"
)
type Manager struct {
rc *resty.Client
status Status
observers []StatusObserver
statusLock sync.Mutex
errHandlers map[Code][]Handler
attPoolSize int
verifyProofs bool
}
func New(opts ...Option) *Manager {
builder := newManagerBuilder()
for _, opt := range opts {
opt.config(builder)
}
return builder.build()
}
func (m *Manager) AddStatusObserver(observer StatusObserver) {
m.statusLock.Lock()
defer m.statusLock.Unlock()
m.observers = append(m.observers, observer)
}
func (m *Manager) AddPreRequestHook(hook resty.RequestMiddleware) {
m.rc.OnBeforeRequest(hook)
}
func (m *Manager) AddPostRequestHook(hook resty.ResponseMiddleware) {
m.rc.OnAfterResponse(hook)
}
func (m *Manager) AddErrorHandler(code Code, handler Handler) {
m.errHandlers[code] = append(m.errHandlers[code], handler)
}
func (m *Manager) Close() {
m.rc.GetClient().CloseIdleConnections()
}
func (m *Manager) r(ctx context.Context) *resty.Request {
return m.rc.R().SetContext(ctx)
}
func (m *Manager) handleError(req *resty.Request, err error) {
resErr, ok := err.(*resty.ResponseError)
if !ok {
return
}
apiErr, ok := resErr.Response.Error().(*Error)
if !ok {
return
}
for _, handler := range m.errHandlers[apiErr.Code] {
handler()
}
}
func (m *Manager) checkConnUp(_ *resty.Client, res *resty.Response) error {
m.onConnUp()
return nil
}
func (m *Manager) checkConnDown(req *resty.Request, err error) {
switch {
case errors.Is(err, context.Canceled):
return
}
if res, ok := err.(*resty.ResponseError); ok && res.Response.RawResponse != nil {
m.onConnUp()
} else {
m.onConnDown()
}
}
func (m *Manager) onConnDown() {
m.statusLock.Lock()
defer m.statusLock.Unlock()
if m.status == StatusDown {
return
}
m.status = StatusDown
for _, observer := range m.observers {
observer(m.status)
}
}
func (m *Manager) onConnUp() {
m.statusLock.Lock()
defer m.statusLock.Unlock()
if m.status == StatusUp {
return
}
m.status = StatusUp
for _, observer := range m.observers {
observer(m.status)
}
}

123
manager_auth.go Normal file
View File

@@ -0,0 +1,123 @@
package proton
import (
"bytes"
"context"
"encoding/base64"
"errors"
"time"
"github.com/ProtonMail/go-srp"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
var ErrInvalidProof = errors.New("unexpected server proof")
func (m *Manager) NewClient(uid, acc, ref string, exp time.Time) *Client {
return newClient(m, uid).withAuth(acc, ref, exp)
}
func (m *Manager) NewClientWithRefresh(ctx context.Context, uid, ref string) (*Client, Auth, error) {
c := newClient(m, uid)
auth, err := m.authRefresh(ctx, uid, ref)
if err != nil {
return nil, Auth{}, err
}
return c.withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
}
func (m *Manager) NewClientWithLogin(ctx context.Context, username string, password []byte) (*Client, Auth, error) {
info, err := m.getAuthInfo(ctx, AuthInfoReq{Username: username})
if err != nil {
return nil, Auth{}, err
}
srpAuth, err := srp.NewAuth(info.Version, username, password, info.Salt, info.Modulus, info.ServerEphemeral)
if err != nil {
return nil, Auth{}, err
}
proofs, err := srpAuth.GenerateProofs(2048)
if err != nil {
return nil, Auth{}, err
}
auth, err := m.auth(ctx, AuthReq{
Username: username,
ClientProof: base64.StdEncoding.EncodeToString(proofs.ClientProof),
ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral),
SRPSession: info.SRPSession,
})
if err != nil {
return nil, Auth{}, err
}
serverProof, err := base64.StdEncoding.DecodeString(auth.ServerProof)
if err != nil {
return nil, Auth{}, err
}
if m.verifyProofs {
if !bytes.Equal(serverProof, proofs.ExpectedServerProof) {
return nil, Auth{}, ErrInvalidProof
}
}
return newClient(m, auth.UID).withAuth(auth.AccessToken, auth.RefreshToken, expiresIn(auth.ExpiresIn)), auth, nil
}
func (m *Manager) getAuthInfo(ctx context.Context, req AuthInfoReq) (AuthInfo, error) {
var res struct {
AuthInfo
}
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/core/v4/auth/info"); err != nil {
return AuthInfo{}, err
}
return res.AuthInfo, nil
}
func (m *Manager) auth(ctx context.Context, req AuthReq) (Auth, error) {
var res struct {
Auth
}
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/core/v4/auth"); err != nil {
return Auth{}, err
}
return res.Auth, nil
}
func (m *Manager) authRefresh(ctx context.Context, uid, ref string) (Auth, error) {
state, err := crypto.RandomToken(32)
if err != nil {
return Auth{}, err
}
req := AuthRefreshReq{
UID: uid,
RefreshToken: ref,
ResponseType: "token",
GrantType: "refresh_token",
RedirectURI: "https://protonmail.ch",
State: string(state),
}
var res struct {
Auth
}
if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/core/v4/auth/refresh"); err != nil {
return Auth{}, err
}
return res.Auth, nil
}
func expiresIn(seconds int) time.Time {
return time.Now().Add(time.Duration(seconds) * time.Second)
}

98
manager_builder.go Normal file
View File

@@ -0,0 +1,98 @@
package proton
import (
"net/http"
"runtime"
"time"
"github.com/go-resty/resty/v2"
)
const (
// DefaultHostURL is the default host of the API.
DefaultHostURL = "https://mail.proton.me/api"
// DefaultAppVersion is the default app version used to communicate with the API.
// This must be changed (using the WithAppVersion option) for production use.
DefaultAppVersion = "go-proton-api"
)
type managerBuilder struct {
hostURL string
appVersion string
transport http.RoundTripper
attPoolSize int
verifyProofs bool
cookieJar http.CookieJar
retryCount int
logger resty.Logger
debug bool
}
func newManagerBuilder() *managerBuilder {
return &managerBuilder{
hostURL: DefaultHostURL,
appVersion: DefaultAppVersion,
transport: http.DefaultTransport,
attPoolSize: runtime.NumCPU(),
verifyProofs: true,
cookieJar: nil,
retryCount: 3,
logger: nil,
debug: false,
}
}
func (builder *managerBuilder) build() *Manager {
m := &Manager{
rc: resty.New(),
errHandlers: make(map[Code][]Handler),
attPoolSize: builder.attPoolSize,
verifyProofs: builder.verifyProofs,
}
// Set the API host.
m.rc.SetBaseURL(builder.hostURL)
// Set the transport.
m.rc.SetTransport(builder.transport)
// Set the cookie jar.
m.rc.SetCookieJar(builder.cookieJar)
// Set the logger.
if builder.logger != nil {
m.rc.SetLogger(builder.logger)
}
// Set the debug flag.
m.rc.SetDebug(builder.debug)
// Set app version in header.
m.rc.OnBeforeRequest(func(_ *resty.Client, req *resty.Request) error {
req.SetHeader("x-pm-appversion", builder.appVersion)
return nil
})
// Set middleware.
m.rc.OnAfterResponse(catchAPIError)
m.rc.OnAfterResponse(updateTime)
m.rc.OnAfterResponse(m.checkConnUp)
m.rc.OnError(m.checkConnDown)
m.rc.OnError(m.handleError)
// Configure retry mechanism.
m.rc.SetRetryCount(builder.retryCount)
m.rc.SetRetryMaxWaitTime(time.Minute)
m.rc.AddRetryCondition(catchTooManyRequests)
m.rc.AddRetryCondition(catchDialError)
m.rc.SetRetryAfter(catchRetryAfter)
// Set the data type of API errors.
m.rc.SetError(&Error{})
return m
}

48
manager_download.go Normal file
View File

@@ -0,0 +1,48 @@
package proton
import (
"context"
"io"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
func (m *Manager) DownloadAndVerify(ctx context.Context, kr *crypto.KeyRing, url, sig string) ([]byte, error) {
fb, err := m.fetchFile(ctx, url)
if err != nil {
return nil, err
}
sb, err := m.fetchFile(ctx, sig)
if err != nil {
return nil, err
}
if err := kr.VerifyDetached(
crypto.NewPlainMessage(fb),
crypto.NewPGPSignature(sb),
crypto.GetUnixTime(),
); err != nil {
return nil, err
}
return fb, nil
}
func (m *Manager) fetchFile(ctx context.Context, url string) ([]byte, error) {
res, err := m.r(ctx).SetDoNotParseResponse(true).Get(url)
if err != nil {
return nil, err
}
b, err := io.ReadAll(res.RawBody())
if err != nil {
return nil, err
}
if err := res.RawBody().Close(); err != nil {
return nil, err
}
return b, nil
}

15
manager_ping.go Normal file
View File

@@ -0,0 +1,15 @@
package proton
import "context"
func (m *Manager) Ping(ctx context.Context) error {
if res, err := m.r(ctx).Get("/tests/ping"); err != nil {
if res.RawResponse != nil {
return nil
}
return err
}
return nil
}

20
manager_report.go Normal file
View File

@@ -0,0 +1,20 @@
package proton
import (
"bytes"
"context"
)
func (m *Manager) ReportBug(ctx context.Context, req ReportBugReq, atts ...ReportBugAttachment) error {
r := m.r(ctx).SetMultipartFormData(req.toFormData())
for _, att := range atts {
r = r.SetMultipartField(att.Name, att.Filename, string(att.MIMEType), bytes.NewReader(att.Body))
}
if _, err := r.Post("/core/v4/reports/bug"); err != nil {
return err
}
return nil
}

50
manager_report_test.go Normal file
View File

@@ -0,0 +1,50 @@
package proton_test
import (
"bytes"
"context"
"mime"
"mime/multipart"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/stretchr/testify/require"
)
func TestReportBug(t *testing.T) {
s := server.New()
defer s.Close()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.InsecureTransport()),
)
defer m.Close()
var calls []server.Call
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
require.NoError(t, m.ReportBug(context.Background(), proton.ReportBugReq{
OS: "linux",
OSVersion: "5.4.0-42-generic",
Browser: "firefox",
ClientType: proton.ClientTypeEmail,
}))
mimeType, mimeParams, err := mime.ParseMediaType(calls[0].RequestHeader.Get("Content-Type"))
require.NoError(t, err)
require.Equal(t, "multipart/form-data", mimeType)
form, err := multipart.NewReader(bytes.NewReader(calls[0].RequestBody), mimeParams["boundary"]).ReadForm(0)
require.NoError(t, err)
require.Len(t, form.Value, 4)
require.Equal(t, "linux", form.Value["OS"][0])
require.Equal(t, "5.4.0-42-generic", form.Value["OSVersion"][0])
require.Equal(t, "firefox", form.Value["Browser"][0])
require.Equal(t, "1", form.Value["ClientType"][0])
}

72
manager_report_types.go Normal file
View File

@@ -0,0 +1,72 @@
package proton
import (
"encoding/json"
"fmt"
"github.com/ProtonMail/gluon/rfc822"
)
type ClientType int
const (
ClientTypeEmail ClientType = iota + 1
ClientTypeVPN
ClientTypeCalendar
ClientTypeDrive
)
type ReportBugReq struct {
OS string
OSVersion string
Browser string
BrowserVersion string
BrowserExtensions string
Resolution string
DisplayMode string
Client string
ClientVersion string
ClientType ClientType
Title string
Description string
Username string
Email string
Country string
ISP string
}
func (req ReportBugReq) toFormData() map[string]string {
b, err := json.Marshal(req)
if err != nil {
panic(err)
}
var raw map[string]any
if err := json.Unmarshal(b, &raw); err != nil {
panic(err)
}
res := make(map[string]string)
for key := range raw {
if val := fmt.Sprint(raw[key]); val != "" {
res[key] = val
}
}
return res
}
type ReportBugAttachment struct {
Name string
Filename string
MIMEType rfc822.MIMEType
Body []byte
}

23
manager_status.go Normal file
View File

@@ -0,0 +1,23 @@
package proton
type Status int
const (
StatusUp Status = iota
StatusDown
)
func (s Status) String() string {
switch s {
case StatusUp:
return "up"
case StatusDown:
return "down"
default:
return "unknown"
}
}
type StatusObserver func(Status)

283
manager_status_test.go Normal file
View File

@@ -0,0 +1,283 @@
package proton_test
import (
"context"
"crypto/tls"
"net"
"testing"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/stretchr/testify/require"
)
func TestStatus(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
var (
called int
status proton.Status
)
m.AddStatusObserver(func(val proton.Status) {
called++
status = val
})
// This should succeed.
require.NoError(t, m.Ping(context.Background()))
// Status should not have been called yet.
require.Zero(t, called)
// Now we simulate a network failure.
netCtl.Disable()
// This should fail.
require.Error(t, m.Ping(context.Background()))
// Status should have been called once and status should indicate network is down.
require.Equal(t, 1, called)
require.Equal(t, proton.StatusDown, status)
// Now we simulate a network restoration.
netCtl.Enable()
// This should succeed.
require.NoError(t, m.Ping(context.Background()))
// Status should have been called twice and status should indicate network is up.
require.Equal(t, 2, called)
require.Equal(t, proton.StatusUp, status)
}
func TestStatus_NoDial(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
var (
called int
status proton.Status
)
m.AddStatusObserver(func(val proton.Status) {
called++
status = val
})
// Disable dialing.
netCtl.SetCanDial(false)
// This should fail.
require.Error(t, m.Ping(context.Background()))
// Status should have been called once and status should indicate network is down.
require.Equal(t, 1, called)
require.Equal(t, proton.StatusDown, status)
}
func TestStatus_NoRead(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
var (
called int
status proton.Status
)
m.AddStatusObserver(func(val proton.Status) {
called++
status = val
})
// Disable reading.
netCtl.SetCanRead(false)
// This should fail.
require.Error(t, m.Ping(context.Background()))
// Status should have been called once and status should indicate network is down.
require.Equal(t, 1, called)
require.Equal(t, proton.StatusDown, status)
}
func TestStatus_NoWrite(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
var (
called int
status proton.Status
)
m.AddStatusObserver(func(val proton.Status) {
called++
status = val
})
// Disable writing.
netCtl.SetCanWrite(false)
// This should fail.
require.Error(t, m.Ping(context.Background()))
// Status should have been called once and status should indicate network is down.
require.Equal(t, 1, called)
require.Equal(t, proton.StatusDown, status)
}
func TestStatus_NoReadExistingConn(t *testing.T) {
s := server.New()
defer s.Close()
_, _, err := s.CreateUser("user", "user@pm.me", []byte("pass"))
require.NoError(t, err)
netCtl := proton.NewNetCtl()
var dialed int
netCtl.OnDial(func(net.Conn) {
dialed++
})
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
// This should succeed.
c, _, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
defer c.Close()
// We should have dialed once.
require.Equal(t, 1, dialed)
// Disable reading on the existing connection.
netCtl.SetCanRead(false)
// This should fail because we won't be able to read the response.
require.Error(t, getErr(c.GetUser(context.Background())))
// We should still have dialed once; the connection should have been reused.
require.Equal(t, 1, dialed)
}
func TestStatus_NoWriteExistingConn(t *testing.T) {
s := server.New()
defer s.Close()
_, _, err := s.CreateUser("user", "user@pm.me", []byte("pass"))
require.NoError(t, err)
netCtl := proton.NewNetCtl()
var dialed int
netCtl.OnDial(func(net.Conn) {
dialed++
})
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithRetryCount(0),
)
// This should succeed.
c, _, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
defer c.Close()
// We should have dialed once.
require.Equal(t, 1, dialed)
// Disable reading on the existing connection.
netCtl.SetCanWrite(false)
// This should fail because we won't be able to write the request.
require.Error(t, c.LabelMessages(context.Background(), []string{"messageID"}, proton.TrashLabel))
// We should still have dialed twice; the connection could not be reused because the write failed.
require.Equal(t, 2, dialed)
}
func TestStatus_ContextCancel(t *testing.T) {
s := server.New()
defer s.Close()
m := proton.New(proton.WithHostURL(s.GetHostURL()))
var called int
m.AddStatusObserver(func(val proton.Status) {
called++
})
// Create a context that will be canceled.
ctx, cancel := context.WithCancel(context.Background())
cancel()
// This should fail because the context is canceled.
require.Error(t, m.Ping(ctx))
// Status should not have been called; this was not a network error.
require.Zero(t, called)
}
func TestStatus_ContextTimeout(t *testing.T) {
s := server.New()
defer s.Close()
m := proton.New(proton.WithHostURL(s.GetHostURL()))
var called int
m.AddStatusObserver(func(val proton.Status) {
called++
})
// Create a context that will time out.
ctx, cancel := context.WithTimeout(context.Background(), 0)
cancel()
// This should fail because the context is canceled.
require.Error(t, m.Ping(ctx))
// Status should have been called; this was a network error (took too long).
require.NotZero(t, called)
}
func getErr[T any](_ T, err error) error {
return err
}

314
manager_test.go Normal file
View File

@@ -0,0 +1,314 @@
package proton_test
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/stretchr/testify/require"
)
func TestConnectionReuse(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
var dialed int
netCtl.OnDial(func(net.Conn) {
dialed++
})
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
// This should succeed; the resulting connection should be reused.
require.NoError(t, m.Ping(context.Background()))
// We should have dialed once.
require.Equal(t, 1, dialed)
// This should succeed; we should not re-dial.
require.NoError(t, m.Ping(context.Background()))
// We should not have re-dialed.
require.Equal(t, 1, dialed)
}
func TestAuthRefresh(t *testing.T) {
s := server.New()
defer s.Close()
_, _, err := s.CreateUser("user", "email@pm.me", []byte("pass"))
require.NoError(t, err)
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.InsecureTransport()),
)
c1, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
defer c1.Close()
c2, auth, err := m.NewClientWithRefresh(context.Background(), auth.UID, auth.RefreshToken)
require.NoError(t, err)
defer c2.Close()
}
func TestHandleTooManyRequests(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
if numCalls < 5 {
w.WriteHeader(http.StatusTooManyRequests)
} else {
w.WriteHeader(http.StatusOK)
}
}))
defer ts.Close()
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
)
// The call should succeed because the 5th retry should succeed (429s are retried).
c := m.NewClient("", "", "", time.Now().Add(time.Hour))
defer c.Close()
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
}
// The server should be called 5 times.
// The first four calls should return 429 and the last call should return 200.
if numCalls != 5 {
t.Fatal("expected numCalls to be 5, instead got", numCalls)
}
}
func TestHandleUnprocessableEntity(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
w.WriteHeader(http.StatusUnprocessableEntity)
}))
defer ts.Close()
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
)
// The call should fail because the first call should fail (422s are not retried).
c := m.NewClient("", "", "", time.Now().Add(time.Hour))
defer c.Close()
if _, err := c.GetAddresses(context.Background()); err == nil {
t.Fatal("expected error, instead got", err)
}
// The server should be called 1 time.
// The first call should return 422.
if numCalls != 1 {
t.Fatal("expected numCalls to be 1, instead got", numCalls)
}
}
func TestHandleDialFailure(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
proton.WithTransport(newFailingRoundTripper(5)),
)
// The call should succeed because the last retry should succeed (dial errors are retried).
c := m.NewClient("", "", "", time.Now().Add(time.Hour))
defer c.Close()
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
}
// The server should be called 1 time.
// The first 4 attempts don't reach the server.
if numCalls != 1 {
t.Fatal("expected numCalls to be 1, instead got", numCalls)
}
}
func TestHandleTooManyDialFailures(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
// The failingRoundTripper will fail the first 10 times it is used.
// This is more than the number of retries we permit.
// Thus, dials will fail.
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
proton.WithTransport(newFailingRoundTripper(10)),
)
// The call should fail because every dial will fail and we'll run out of retries.
c := m.NewClient("", "", "", time.Now().Add(time.Hour))
defer c.Close()
if _, err := c.GetAddresses(context.Background()); err == nil {
t.Fatal("expected error, instead got", err)
}
// The server should never be called.
if numCalls != 0 {
t.Fatal("expected numCalls to be 0, instead got", numCalls)
}
}
func TestRetriesWithContextTimeout(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
if numCalls < 5 {
w.WriteHeader(http.StatusTooManyRequests)
} else {
w.WriteHeader(http.StatusOK)
}
time.Sleep(time.Second)
}))
defer ts.Close()
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
)
// Timeout after 1s.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Theoretically, this should succeed; on the fifth retry, we'll get StatusOK.
// However, that will take at least >5s, and we only allow 1s in the context.
// Thus, it will fail.
c := m.NewClient("", "", "", time.Now().Add(time.Hour))
defer c.Close()
if _, err := c.GetAddresses(ctx); err == nil {
t.Fatal("expected error, instead got", err)
}
}
func TestReturnErrNoConnection(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
// We will fail more times than we retry, so requests should fail with ErrNoConnection.
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
proton.WithTransport(newFailingRoundTripper(10)),
)
// The call should fail because every dial will fail and we'll run out of retries.
c := m.NewClient("", "", "", time.Now().Add(time.Hour))
defer c.Close()
if _, err := c.GetAddresses(context.Background()); err == nil {
t.Fatal("expected error, instead got", err)
}
}
func TestStatusCallbacks(t *testing.T) {
s := server.New()
defer s.Close()
ctl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
)
statusCh := make(chan proton.Status, 1)
m.AddStatusObserver(func(status proton.Status) {
statusCh <- status
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctl.Disable()
require.Error(t, m.Ping(ctx))
require.Equal(t, proton.StatusDown, <-statusCh)
ctl.Enable()
require.NoError(t, m.Ping(ctx))
require.Equal(t, proton.StatusUp, <-statusCh)
ctl.SetReadLimit(1)
require.Error(t, m.Ping(ctx))
require.Equal(t, proton.StatusDown, <-statusCh)
ctl.SetReadLimit(0)
require.NoError(t, m.Ping(ctx))
require.Equal(t, proton.StatusUp, <-statusCh)
}
type failingRoundTripper struct {
http.RoundTripper
fails, calls int
}
func newFailingRoundTripper(fails int) http.RoundTripper {
return &failingRoundTripper{
RoundTripper: http.DefaultTransport,
fails: fails,
}
}
func (rt *failingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.calls++
if rt.calls < rt.fails {
return nil, errors.New("simulating network error")
}
return rt.RoundTripper.RoundTrip(req)
}

274
message.go Normal file
View File

@@ -0,0 +1,274 @@
package proton
import (
"context"
"fmt"
"runtime"
"strconv"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/go-resty/resty/v2"
)
const maxMessageIDs = 1000
func (c *Client) GetFullMessage(ctx context.Context, messageID string) (FullMessage, error) {
message, err := c.GetMessage(ctx, messageID)
if err != nil {
return FullMessage{}, err
}
attData, err := c.attPool().ProcessAll(ctx, xslices.Map(message.Attachments, func(att Attachment) string {
return att.ID
}))
if err != nil {
return FullMessage{}, err
}
return FullMessage{
Message: message,
AttData: attData,
}, nil
}
func (c *Client) GetFullMessages(ctx context.Context, workers, buffer int, messageIDs ...string) stream.Stream[FullMessage] {
return parallel.MapStream(
ctx,
stream.FromIterator(iterator.Slice(messageIDs)),
workers,
buffer,
c.GetFullMessage,
)
}
func (c *Client) GetMessage(ctx context.Context, messageID string) (Message, error) {
var res struct {
Message Message
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/mail/v4/messages/" + messageID)
}); err != nil {
return Message{}, err
}
return res.Message, nil
}
func (c *Client) CountMessages(ctx context.Context) (int, error) {
var res struct {
Total int
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetQueryParams(map[string]string{
"Limit": strconv.Itoa(0),
}).SetResult(&res).Get("/mail/v4/messages")
}); err != nil {
return 0, err
}
return res.Total, nil
}
func (c *Client) GetMessageMetadata(ctx context.Context, filter MessageFilter) ([]MessageMetadata, error) {
total, err := c.CountMessages(ctx)
if err != nil {
return nil, err
}
return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]MessageMetadata, error) {
return c.getMessageMetadata(ctx, page, pageSize, filter)
})
}
func (c *Client) GetMessageIDs(ctx context.Context, afterID string) ([]string, error) {
var messageIDs []string
for ; ; afterID = messageIDs[len(messageIDs)-1] {
page, err := c.getMessageIDs(ctx, afterID)
if err != nil {
return nil, err
}
if len(page) == 0 {
return messageIDs, nil
}
messageIDs = append(messageIDs, page...)
}
}
func (c *Client) DeleteMessage(ctx context.Context, messageIDs ...string) error {
pages := xslices.Chunk(messageIDs, maxPageSize)
return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/delete")
})
})
}
func (c *Client) MarkMessagesRead(ctx context.Context, messageIDs ...string) error {
pages := xslices.Chunk(messageIDs, maxPageSize)
return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error {
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/read")
})
})
}
func (c *Client) MarkMessagesUnread(ctx context.Context, messageIDs ...string) error {
pages := xslices.Chunk(messageIDs, maxPageSize)
return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error {
req := MessageActionReq{IDs: pages[idx]}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).Put("/mail/v4/messages/unread")
}); err != nil {
return err
}
return nil
})
}
func (c *Client) LabelMessages(ctx context.Context, messageIDs []string, labelID string) error {
res, err := parallel.MapContext(
ctx,
runtime.NumCPU(),
xslices.Chunk(messageIDs, maxPageSize),
func(ctx context.Context, messageIDs []string) (LabelMessagesRes, error) {
var res LabelMessagesRes
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(LabelMessagesReq{
LabelID: labelID,
IDs: messageIDs,
}).SetResult(&res).Put("/mail/v4/messages/label")
}); err != nil {
return LabelMessagesRes{}, err
}
return res, nil
},
)
if err != nil {
return err
}
if idx := xslices.IndexFunc(res, func(res LabelMessagesRes) bool { return !res.ok() }); idx >= 0 {
tokens := xslices.Map(res, func(res LabelMessagesRes) UndoToken {
return res.UndoToken
})
if _, undoErr := c.UndoActions(ctx, tokens...); undoErr != nil {
return fmt.Errorf("failed to undo actions: %w", undoErr)
}
return fmt.Errorf("failed to label messages")
}
return nil
}
func (c *Client) UnlabelMessages(ctx context.Context, messageIDs []string, labelID string) error {
res, err := parallel.MapContext(
ctx,
runtime.NumCPU(),
xslices.Chunk(messageIDs, maxPageSize),
func(ctx context.Context, messageIDs []string) (LabelMessagesRes, error) {
var res LabelMessagesRes
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(LabelMessagesReq{
LabelID: labelID,
IDs: messageIDs,
}).SetResult(&res).Put("/mail/v4/messages/unlabel")
}); err != nil {
return LabelMessagesRes{}, err
}
return res, nil
},
)
if err != nil {
return err
}
if idx := xslices.IndexFunc(res, func(res LabelMessagesRes) bool { return !res.ok() }); idx >= 0 {
tokens := xslices.Map(res, func(res LabelMessagesRes) UndoToken {
return res.UndoToken
})
if _, undoErr := c.UndoActions(ctx, tokens...); undoErr != nil {
return fmt.Errorf("failed to undo actions: %w", undoErr)
}
return fmt.Errorf("failed to unlabel messages")
}
return nil
}
func (c *Client) getMessageIDs(ctx context.Context, afterID string) ([]string, error) {
var res struct {
IDs []string
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
if afterID != "" {
r = r.SetQueryParam("AfterID", afterID)
}
return r.SetQueryParam("Limit", strconv.Itoa(maxMessageIDs)).SetResult(&res).Get("/mail/v4/messages/ids")
}); err != nil {
return nil, err
}
return res.IDs, nil
}
func (c *Client) getMessageMetadata(ctx context.Context, page, pageSize int, filter MessageFilter) ([]MessageMetadata, error) {
var res struct {
Messages []MessageMetadata
Stale Bool
}
req := struct {
MessageFilter
Page int
PageSize int
Sort string
Desc Bool
}{
MessageFilter: filter,
Page: page,
PageSize: pageSize,
Sort: "ID",
Desc: false,
}
for {
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).SetHeader("X-HTTP-Method-Override", "GET").Post("/mail/v4/messages")
}); err != nil {
return nil, err
}
if !res.Stale {
break
}
}
return res.Messages, nil
}

318
message_build.go Normal file
View File

@@ -0,0 +1,318 @@
package proton
import (
"bufio"
"bytes"
"encoding/base64"
"io"
"mime"
"net/mail"
"strings"
"time"
"unicode/utf8"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/emersion/go-message"
"github.com/emersion/go-message/textproto"
"github.com/google/uuid"
)
func BuildRFC822(kr *crypto.KeyRing, msg Message, attData map[string][]byte) ([]byte, error) {
if msg.MIMEType == rfc822.MultipartMixed {
return buildPGPRFC822(kr, msg)
}
header, err := getMixedMessageHeader(msg)
if err != nil {
return nil, err
}
buf := new(bytes.Buffer)
w, err := message.CreateWriter(buf, header)
if err != nil {
return nil, err
}
var (
inlineAtts []Attachment
inlineData [][]byte
attachAtts []Attachment
attachData [][]byte
)
for _, att := range msg.Attachments {
if att.Disposition == InlineDisposition {
inlineAtts = append(inlineAtts, att)
inlineData = append(inlineData, attData[att.ID])
} else {
attachAtts = append(attachAtts, att)
attachData = append(attachData, attData[att.ID])
}
}
if len(inlineAtts) > 0 {
if err := writeRelatedParts(w, kr, msg, inlineAtts, inlineData); err != nil {
return nil, err
}
} else if err := writeTextPart(w, kr, msg); err != nil {
return nil, err
}
for i, att := range attachAtts {
if err := writeAttachmentPart(w, kr, att, attachData[i]); err != nil {
return nil, err
}
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func writeTextPart(w *message.Writer, kr *crypto.KeyRing, msg Message) error {
dec, err := msg.Decrypt(kr)
if err != nil {
return err
}
part, err := w.CreatePart(getTextPartHeader(dec, msg.MIMEType))
if err != nil {
return err
}
if _, err := part.Write(dec); err != nil {
return err
}
return part.Close()
}
func writeAttachmentPart(w *message.Writer, kr *crypto.KeyRing, att Attachment, attData []byte) error {
kps, err := base64.StdEncoding.DecodeString(att.KeyPackets)
if err != nil {
return err
}
msg := crypto.NewPGPSplitMessage(kps, attData).GetPGPMessage()
dec, err := kr.Decrypt(msg, nil, crypto.GetUnixTime())
if err != nil {
return err
}
part, err := w.CreatePart(getAttachmentPartHeader(att))
if err != nil {
return err
}
if _, err := part.Write(dec.GetBinary()); err != nil {
return err
}
return part.Close()
}
func writeRelatedParts(w *message.Writer, kr *crypto.KeyRing, msg Message, atts []Attachment, attData [][]byte) error {
var header message.Header
header.SetContentType(string(rfc822.MultipartRelated), nil)
rel, err := w.CreatePart(header)
if err != nil {
return err
}
if err := writeTextPart(rel, kr, msg); err != nil {
return err
}
for i, att := range atts {
if err := writeAttachmentPart(rel, kr, att, attData[i]); err != nil {
return err
}
}
return rel.Close()
}
func buildPGPRFC822(kr *crypto.KeyRing, msg Message) ([]byte, error) {
raw, err := textproto.ReadHeader(bufio.NewReader(strings.NewReader(msg.Header)))
if err != nil {
return nil, err
}
dec, err := msg.Decrypt(kr)
if err != nil {
return nil, err
}
sigs, err := ExtractSignatures(kr, msg.Body)
if err != nil {
return nil, err
}
if len(sigs) > 0 {
return buildMultipartSignedRFC822(message.Header{Header: raw}, dec, sigs[0])
}
return buildMultipartEncryptedRFC822(message.Header{Header: raw}, dec)
}
func buildMultipartSignedRFC822(header message.Header, body []byte, sig Signature) ([]byte, error) {
buf := new(bytes.Buffer)
boundary := uuid.New().String()
header.SetContentType("multipart/signed", map[string]string{
"micalg": sig.Hash,
"protocol": "application/pgp-signature",
"boundary": boundary,
})
if err := textproto.WriteHeader(buf, header.Header); err != nil {
return nil, err
}
w := rfc822.NewMultipartWriter(buf, boundary)
bodyHeader, bodyData := rfc822.Split(body)
if err := w.AddPart(func(w io.Writer) error {
if _, err := w.Write(bodyHeader); err != nil {
return err
}
if _, err := w.Write(bodyData); err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
var sigHeader message.Header
sigHeader.SetContentType("application/pgp-signature", map[string]string{"name": "OpenPGP_signature.asc"})
sigHeader.SetContentDisposition("attachment", map[string]string{"filename": "OpenPGP_signature"})
sigHeader.Set("Content-Description", "OpenPGP digital signature")
sigData, err := sig.Data.GetArmored()
if err != nil {
return nil, err
}
if err := w.AddPart(func(w io.Writer) error {
if err := textproto.WriteHeader(w, sigHeader.Header); err != nil {
return err
}
if _, err := w.Write([]byte(sigData)); err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
if err := w.Done(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func buildMultipartEncryptedRFC822(header message.Header, body []byte) ([]byte, error) {
buf := new(bytes.Buffer)
bodyHeader, bodyData := rfc822.Split(body)
parsedHeader, err := rfc822.NewHeader(bodyHeader)
if err != nil {
return nil, err
}
parsedHeader.Entries(func(key, val string) {
header.Set(key, val)
})
if err := textproto.WriteHeader(buf, header.Header); err != nil {
return nil, err
}
if _, err := buf.Write(bodyData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func getMixedMessageHeader(msg Message) (message.Header, error) {
raw, err := textproto.ReadHeader(bufio.NewReader(strings.NewReader(msg.Header)))
if err != nil {
return message.Header{}, err
}
header := message.Header{Header: raw}
header.SetContentType(string(rfc822.MultipartMixed), nil)
if date, err := mail.ParseDate(header.Get("Date")); err != nil || date.Before(time.Unix(0, 0)) {
if msgTime := time.Unix(msg.Time, 0); msgTime.After(time.Unix(0, 0)) {
header.Set("Date", msgTime.In(time.UTC).Format(time.RFC1123Z))
} else {
header.Del("Date")
}
header.Set("X-Original-Date", date.In(time.UTC).Format(time.RFC1123Z))
}
return header, nil
}
func getTextPartHeader(body []byte, mimeType rfc822.MIMEType) message.Header {
var header message.Header
params := make(map[string]string)
if utf8.Valid(body) {
params["charset"] = "utf-8"
}
header.SetContentType(string(mimeType), params)
// Use quoted-printable for all text/... parts
header.Set("Content-Transfer-Encoding", "quoted-printable")
return header
}
func getAttachmentPartHeader(att Attachment) message.Header {
var header message.Header
for key, val := range att.Headers {
for _, val := range val {
header.Add(key, val)
}
}
// All attachments have a content type.
header.SetContentType(string(att.MIMEType), map[string]string{"name": mime.QEncoding.Encode("utf-8", att.Name)})
// All attachments have a content disposition.
header.SetContentDisposition(string(att.Disposition), map[string]string{"filename": mime.QEncoding.Encode("utf-8", att.Name)})
// Use base64 for all attachments except embedded RFC822 messages.
if att.MIMEType != rfc822.MessageRFC822 {
header.Set("Content-Transfer-Encoding", "base64")
} else {
header.Del("Content-Transfer-Encoding")
}
return header
}

37
message_draft_types.go Normal file
View File

@@ -0,0 +1,37 @@
package proton
import (
"net/mail"
"github.com/ProtonMail/gluon/rfc822"
)
type DraftTemplate struct {
Subject string
Sender *mail.Address
ToList []*mail.Address
CCList []*mail.Address
BCCList []*mail.Address
Body string
MIMEType rfc822.MIMEType
ExternalID string `json:",omitempty"`
}
type CreateDraftAction int
const (
ReplyAction CreateDraftAction = iota
ReplyAllAction
ForwardAction
AutoResponseAction
ReadReceiptAction
)
type CreateDraftReq struct {
Message DraftTemplate
AttachmentKeyPackets []string
ParentID string `json:",omitempty"`
Action CreateDraftAction
}

123
message_encrypt.go Normal file
View File

@@ -0,0 +1,123 @@
package proton
import (
"bytes"
"io"
"mime"
"strings"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/google/uuid"
)
// CharsetReader returns a charset decoder for the given charset.
// If set, it will be used to decode non-utf8 encoded messages.
var CharsetReader func(charset string, input io.Reader) (io.Reader, error)
// EncryptRFC822 encrypts the given message literal as a PGP attachment.
func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) {
msg, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr)
if err != nil {
return nil, err
}
armored, err := msg.GetArmored()
if err != nil {
return nil, err
}
header, _ := rfc822.Split(literal)
headerParsed, err := rfc822.NewHeader(header)
if err != nil {
return nil, err
}
buf := new(bytes.Buffer)
boundary := strings.ReplaceAll(uuid.NewString(), "-", "")
multipartWriter := rfc822.NewMultipartWriter(buf, boundary)
{
newHeader := rfc822.NewEmptyHeader()
if value, ok := headerParsed.GetChecked("Message-Id"); ok {
newHeader.Set("Message-Id", value)
}
contentType := mime.FormatMediaType("multipart/encrypted", map[string]string{
"boundary": boundary,
"protocol": "application/pgp-encrypted",
})
newHeader.Set("Mime-version", "1.0")
newHeader.Set("Content-Type", contentType)
if value, ok := headerParsed.GetChecked("From"); ok {
newHeader.Set("From", value)
}
if value, ok := headerParsed.GetChecked("To"); ok {
newHeader.Set("To", value)
}
if value, ok := headerParsed.GetChecked("Subject"); ok {
newHeader.Set("Subject", value)
}
if value, ok := headerParsed.GetChecked("Date"); ok {
newHeader.Set("Date", value)
}
if value, ok := headerParsed.GetChecked("Received"); ok {
newHeader.Set("Received", value)
}
buf.Write(newHeader.Raw())
}
// Write PGP control data
{
pgpControlHeader := rfc822.NewEmptyHeader()
pgpControlHeader.Set("Content-Description", "PGP/MIME version identification")
pgpControlHeader.Set("Content-Type", "application/pgp-encrypted")
if err := multipartWriter.AddPart(func(writer io.Writer) error {
if _, err := writer.Write(pgpControlHeader.Raw()); err != nil {
return err
}
_, err := writer.Write([]byte("Version: 1"))
return err
}); err != nil {
return nil, err
}
}
// write PGP attachment
{
pgpAttachmentHeader := rfc822.NewEmptyHeader()
contentType := mime.FormatMediaType("application/octet-stream", map[string]string{
"name": "encrypted.asc",
})
pgpAttachmentHeader.Set("Content-Description", "OpenPGP encrypted message")
pgpAttachmentHeader.Set("Content-Disposition", "inline; filename=encrypted.asc")
pgpAttachmentHeader.Set("Content-Type", contentType)
if err := multipartWriter.AddPart(func(writer io.Writer) error {
if _, err := writer.Write(pgpAttachmentHeader.Raw()); err != nil {
return err
}
_, err := writer.Write([]byte(armored))
return err
}); err != nil {
return nil, err
}
}
// finish messsage
if err := multipartWriter.Done(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

96
message_encrypt_test.go Normal file
View File

@@ -0,0 +1,96 @@
package proton
import (
"bytes"
"testing"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEncryptMessage(t *testing.T) {
const message = `From: Nathaniel Borenstein <nsb@bellcore.com>
To: Ned Freed <ned@innosoft.com>
Subject: Sample message (import 2)
MIME-Version: 1.0
Content-type: multipart/mixed; boundary="simple boundary"
Received: from mail.protonmail.ch by mail.protonmail.ch; Tue, 25 Nov 2016
This is the preamble. It is to be ignored, though it
is a handy place for mail composers to include an
explanatory note to non-MIME compliant readers.
--simple boundary
This is implicitly typed plain ASCII text.
It does NOT end with a linebreak.
--simple boundary
Content-type: text/plain; charset=us-ascii
This is explicitly typed plain ASCII text.
It DOES end with a linebreak.
--simple boundary--
This is the epilogue. It is also to be ignored.
`
key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0)
require.NoError(t, err)
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
encryptedMessage, err := EncryptRFC822(kr, []byte(message))
require.NoError(t, err)
section := rfc822.Parse(encryptedMessage)
{
// Check root header:
header, err := section.ParseHeader()
require.NoError(t, err)
assert.Equal(t, header.Get("From"), "Nathaniel Borenstein <nsb@bellcore.com>")
assert.Equal(t, header.Get("To"), "Ned Freed <ned@innosoft.com>")
assert.Equal(t, header.Get("Subject"), "Sample message (import 2)")
assert.Equal(t, header.Get("MIME-Version"), "1.0")
assert.Equal(t, header.Get("Received"), "from mail.protonmail.ch by mail.protonmail.ch; Tue, 25 Nov 2016")
mediaType, params, err := rfc822.ParseMediaType(header.Get("Content-Type"))
require.NoError(t, err)
assert.Equal(t, "multipart/encrypted", mediaType)
assert.Equal(t, "application/pgp-encrypted", params["protocol"])
assert.NotEmpty(t, params["boundary"])
}
children, err := section.Children()
require.NoError(t, err)
require.Equal(t, 2, len(children))
{
// check first child.
child := children[0]
header, err := child.ParseHeader()
require.NoError(t, err)
assert.Equal(t, header.Get("Content-Description"), "PGP/MIME version identification")
assert.Equal(t, header.Get("Content-Type"), "application/pgp-encrypted")
assert.Equal(t, []byte("Version: 1"), child.Body())
}
{
// check second child.
child := children[1]
header, err := child.ParseHeader()
require.NoError(t, err)
assert.Equal(t, header.Get("Content-Description"), "OpenPGP encrypted message")
assert.Equal(t, header.Get("Content-Disposition"), "inline; filename=encrypted.asc")
assert.Equal(t, header.Get("Content-type"), "application/octet-stream; name=encrypted.asc")
body := child.Body()
assert.True(t, bytes.HasPrefix(body, []byte("-----BEGIN PGP MESSAGE-----")))
assert.True(t, bytes.HasSuffix(body, []byte("-----END PGP MESSAGE-----")))
}
}

84
message_import.go Normal file
View File

@@ -0,0 +1,84 @@
package proton
import (
"context"
"fmt"
"strconv"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/stream"
"github.com/bradenaw/juniper/xslices"
"github.com/go-resty/resty/v2"
)
const maxImportSize = 10
func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, workers, buffer int, req ...ImportReq) stream.Stream[ImportRes] {
return stream.Flatten(parallel.MapStream(
ctx,
stream.FromIterator(iterator.Chunk(iterator.Slice(req), maxImportSize)),
workers,
buffer,
func(ctx context.Context, req []ImportReq) (stream.Stream[ImportRes], error) {
res, err := c.importMessages(ctx, addrKR, req)
if err != nil {
return nil, fmt.Errorf("failed to import messages: %w", err)
}
for _, res := range res {
if res.Code != SuccessCode {
return nil, fmt.Errorf("failed to import message: %w", res.Error)
}
}
return stream.FromIterator(iterator.Slice(res)), nil
},
))
}
func (c *Client) importMessages(ctx context.Context, addrKR *crypto.KeyRing, req []ImportReq) ([]ImportRes, error) {
names := iterator.Collect(iterator.Map(iterator.Counter(len(req)), func(i int) string {
return strconv.Itoa(i)
}))
var named []namedImportReq
for idx, name := range names {
named = append(named, namedImportReq{
ImportReq: req[idx],
Name: name,
})
}
fields, err := buildImportReqFields(addrKR, named)
if err != nil {
return nil, err
}
type namedImportRes struct {
Name string
Response ImportRes
}
var res struct {
Responses []namedImportRes
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetMultipartFields(fields...).SetResult(&res).Post("/mail/v4/messages/import")
}); err != nil {
return nil, err
}
namedRes := make(map[string]ImportRes, len(res.Responses))
for _, res := range res.Responses {
namedRes[res.Name] = res.Response
}
return xslices.Map(names, func(name string) ImportRes {
return namedRes[name]
}), nil
}

68
message_import_types.go Normal file
View File

@@ -0,0 +1,68 @@
package proton
import (
"bytes"
"encoding/json"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
type ImportReq struct {
Metadata ImportMetadata
Message []byte
}
type namedImportReq struct {
ImportReq
Name string
}
type ImportMetadata struct {
AddressID string
LabelIDs []string
Unread Bool
Flags MessageFlag
}
type ImportRes struct {
Error
MessageID string
}
func buildImportReqFields(addrKR *crypto.KeyRing, req []namedImportReq) ([]*resty.MultipartField, error) {
var fields []*resty.MultipartField
metadata := make(map[string]ImportMetadata)
for _, req := range req {
metadata[req.Name] = req.Metadata
enc, err := EncryptRFC822(addrKR, req.Message)
if err != nil {
return nil, err
}
fields = append(fields, &resty.MultipartField{
Param: req.Name,
FileName: req.Name + ".eml",
ContentType: string(rfc822.MessageRFC822),
Reader: bytes.NewReader(append(enc, "\r\n"...)),
})
}
b, err := json.Marshal(metadata)
if err != nil {
return nil, err
}
fields = append(fields, &resty.MultipartField{
Param: "Metadata",
ContentType: "application/json",
Reader: bytes.NewReader(b),
})
return fields, nil
}

35
message_send.go Normal file
View File

@@ -0,0 +1,35 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
)
func (c *Client) CreateDraft(ctx context.Context, req CreateDraftReq) (Message, error) {
var res struct {
Message Message
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages")
}); err != nil {
return Message{}, err
}
return res.Message, nil
}
func (c *Client) SendDraft(ctx context.Context, draftID string, req SendDraftReq) (Message, error) {
var res struct {
Sent Message
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages/" + draftID)
}); err != nil {
return Message{}, err
}
return res.Sent, nil
}

308
message_send_types.go Normal file
View File

@@ -0,0 +1,308 @@
package proton
import (
"encoding/base64"
"fmt"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
type EncryptionScheme int
const (
InternalScheme EncryptionScheme = 1 << iota
EncryptedOutsideScheme
ClearScheme
PGPInlineScheme
PGPMIMEScheme
ClearMIMEScheme
)
type SignatureType int
const (
NoSignature SignatureType = iota
DetachedSignature
AttachedSignature
)
type MessageRecipient struct {
Type EncryptionScheme
Signature SignatureType
BodyKeyPacket string `json:",omitempty"`
AttachmentKeyPackets map[string]string `json:",omitempty"`
}
type MessagePackage struct {
Addresses map[string]*MessageRecipient
MIMEType rfc822.MIMEType
Type EncryptionScheme
Body string
BodyKey *SessionKey `json:",omitempty"`
AttachmentKeys map[string]*SessionKey `json:",omitempty"`
}
func newMessagePackage(mimeType rfc822.MIMEType, encBodyData []byte) *MessagePackage {
return &MessagePackage{
Addresses: make(map[string]*MessageRecipient),
MIMEType: mimeType,
Body: base64.StdEncoding.EncodeToString(encBodyData),
AttachmentKeys: make(map[string]*SessionKey),
}
}
type SessionKey struct {
Key string
Algorithm string
}
func newSessionKey(key *crypto.SessionKey) *SessionKey {
return &SessionKey{
Key: key.GetBase64Key(),
Algorithm: key.Algo,
}
}
type SendPreferences struct {
// Encrypt indicates whether the email should be encrypted or not.
// If it's encrypted, we need to know which public key to use.
Encrypt bool
// PubKey contains an OpenPGP key that can be used for encryption.
PubKey *crypto.KeyRing
// SignatureType indicates how the email should be signed.
SignatureType SignatureType
// EncryptionScheme indicates if we should encrypt body and attachments separately and
// what MIME format to give the final encrypted email. The two standard PGP
// schemes are PGP/MIME and PGP/Inline. However we use a custom scheme for
// internal emails (including the so-called encrypted-to-outside emails,
// which even though meant for external users, they don't really get out of
// our platform). If the email is sent unencrypted, no PGP scheme is needed.
EncryptionScheme EncryptionScheme
// MIMEType is the MIME type to use for formatting the body of the email
// (before encryption/after decryption). The standard possibilities are the
// enriched HTML format, text/html, and plain text, text/plain. But it's
// also possible to have a multipart/mixed format, which is typically used
// for PGP/MIME encrypted emails, where attachments go into the body too.
// Because of this, this option is sometimes called MIME format.
MIMEType rfc822.MIMEType
}
type SendDraftReq struct {
Packages []*MessagePackage
}
func (req *SendDraftReq) AddMIMEPackage(
kr *crypto.KeyRing,
mimeBody string,
prefs map[string]SendPreferences,
) error {
for _, prefs := range prefs {
if prefs.MIMEType != rfc822.MultipartMixed {
return fmt.Errorf("invalid MIME type for MIME package: %s", prefs.MIMEType)
}
}
pkg, err := newMIMEPackage(kr, mimeBody, prefs)
if err != nil {
return err
}
req.Packages = append(req.Packages, pkg)
return nil
}
func (req *SendDraftReq) AddTextPackage(
kr *crypto.KeyRing,
body string,
mimeType rfc822.MIMEType,
prefs map[string]SendPreferences,
attKeys map[string]*crypto.SessionKey,
) error {
pkg, err := newTextPackage(kr, body, mimeType, prefs, attKeys)
if err != nil {
return err
}
req.Packages = append(req.Packages, pkg)
return nil
}
func newMIMEPackage(
kr *crypto.KeyRing,
mimeBody string,
prefs map[string]SendPreferences,
) (*MessagePackage, error) {
decBodyKey, encBodyData, err := encSplit(kr, mimeBody)
if err != nil {
return nil, fmt.Errorf("failed to encrypt MIME body: %w", err)
}
pkg := newMessagePackage(rfc822.MultipartMixed, encBodyData)
for addr, prefs := range prefs {
if prefs.MIMEType != rfc822.MultipartMixed {
return nil, fmt.Errorf("invalid MIME type for MIME package: %s", prefs.MIMEType)
}
if prefs.SignatureType != DetachedSignature {
return nil, fmt.Errorf("invalid signature type for MIME package: %d", prefs.SignatureType)
}
recipient := &MessageRecipient{
Type: prefs.EncryptionScheme,
Signature: prefs.SignatureType,
}
switch prefs.EncryptionScheme {
case PGPMIMEScheme:
if prefs.PubKey == nil {
return nil, fmt.Errorf("missing public key for %s", addr)
}
encBodyKey, err := prefs.PubKey.EncryptSessionKey(decBodyKey)
if err != nil {
return nil, fmt.Errorf("failed to encrypt session key: %w", err)
}
recipient.BodyKeyPacket = base64.StdEncoding.EncodeToString(encBodyKey)
case ClearMIMEScheme:
pkg.BodyKey = &SessionKey{
Key: decBodyKey.GetBase64Key(),
Algorithm: decBodyKey.Algo,
}
default:
return nil, fmt.Errorf("invalid encryption scheme for MIME package: %d", prefs.EncryptionScheme)
}
pkg.Addresses[addr] = recipient
pkg.Type |= prefs.EncryptionScheme
}
return pkg, nil
}
func newTextPackage(
kr *crypto.KeyRing,
body string,
mimeType rfc822.MIMEType,
prefs map[string]SendPreferences,
attKeys map[string]*crypto.SessionKey,
) (*MessagePackage, error) {
if mimeType != rfc822.TextPlain && mimeType != rfc822.TextHTML {
return nil, fmt.Errorf("invalid MIME type for package: %s", mimeType)
}
decBodyKey, encBodyData, err := encSplit(kr, body)
if err != nil {
return nil, fmt.Errorf("failed to encrypt message body: %w", err)
}
pkg := newMessagePackage(mimeType, encBodyData)
for addr, prefs := range prefs {
if prefs.MIMEType != mimeType {
return nil, fmt.Errorf("invalid MIME type for package: %s", prefs.MIMEType)
}
if prefs.SignatureType == DetachedSignature && !prefs.Encrypt {
if prefs.EncryptionScheme == PGPInlineScheme {
return nil, fmt.Errorf("invalid encryption scheme for %s: %d", addr, prefs.EncryptionScheme)
}
if prefs.EncryptionScheme == ClearScheme && mimeType != rfc822.TextPlain {
return nil, fmt.Errorf("invalid MIME type for clear package: %s", mimeType)
}
}
if prefs.EncryptionScheme == InternalScheme && !prefs.Encrypt {
return nil, fmt.Errorf("internal packages must be encrypted")
}
if prefs.EncryptionScheme == PGPInlineScheme && mimeType != rfc822.TextPlain {
return nil, fmt.Errorf("invalid MIME type for PGP inline package: %s", mimeType)
}
switch prefs.EncryptionScheme {
case ClearScheme:
pkg.BodyKey = newSessionKey(decBodyKey)
for attID, attKey := range attKeys {
pkg.AttachmentKeys[attID] = newSessionKey(attKey)
}
case InternalScheme, PGPInlineScheme:
// ...
default:
return nil, fmt.Errorf("invalid encryption scheme for package: %d", prefs.EncryptionScheme)
}
recipient := &MessageRecipient{
Type: prefs.EncryptionScheme,
Signature: prefs.SignatureType,
AttachmentKeyPackets: make(map[string]string),
}
if prefs.Encrypt {
if prefs.PubKey == nil {
return nil, fmt.Errorf("missing public key for %s", addr)
}
if prefs.SignatureType != DetachedSignature {
return nil, fmt.Errorf("invalid signature type for package: %d", prefs.SignatureType)
}
encBodyKey, err := prefs.PubKey.EncryptSessionKey(decBodyKey)
if err != nil {
return nil, fmt.Errorf("failed to encrypt session key: %w", err)
}
recipient.BodyKeyPacket = base64.StdEncoding.EncodeToString(encBodyKey)
for attID, attKey := range attKeys {
encAttKey, err := prefs.PubKey.EncryptSessionKey(attKey)
if err != nil {
return nil, fmt.Errorf("failed to encrypt attachment key: %w", err)
}
recipient.AttachmentKeyPackets[attID] = base64.StdEncoding.EncodeToString(encAttKey)
}
}
pkg.Addresses[addr] = recipient
pkg.Type |= prefs.EncryptionScheme
}
return pkg, nil
}
func encSplit(kr *crypto.KeyRing, body string) (*crypto.SessionKey, []byte, error) {
encBody, err := kr.Encrypt(crypto.NewPlainMessageFromString(body), kr)
if err != nil {
return nil, nil, fmt.Errorf("failed to encrypt MIME body: %w", err)
}
splitEncBody, err := encBody.SplitMessage()
if err != nil {
return nil, nil, fmt.Errorf("failed to split message: %w", err)
}
decBodyKey, err := kr.DecryptSessionKey(splitEncBody.GetBinaryKeyPacket())
if err != nil {
return nil, nil, fmt.Errorf("failed to decrypt session key: %w", err)
}
return decBodyKey, splitEncBody.GetBinaryDataPacket(), nil
}

381
message_send_types_test.go Normal file
View File

@@ -0,0 +1,381 @@
package proton
import (
"testing"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require"
)
func TestSendDraftReq_AddMIMEPackage(t *testing.T) {
key, err := crypto.GenerateKey("name", "email", "rsa", 2048)
require.NoError(t, err)
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
tests := []struct {
name string
mimeBody string
prefs map[string]SendPreferences
wantErr bool
}{
{
name: "Clear MIME with detached signature",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-sign@email.com": {
Encrypt: false,
SignatureType: DetachedSignature,
EncryptionScheme: ClearMIMEScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: false,
},
{
name: "Clear MIME with no signature (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-no-sign@email.com": {
Encrypt: false,
SignatureType: NoSignature,
EncryptionScheme: ClearMIMEScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: true,
},
{
name: "Clear MIME with plain text (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-plain@email.com": {
Encrypt: false,
SignatureType: DetachedSignature,
EncryptionScheme: ClearMIMEScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: true,
},
{
name: "Clear MIME with rich text (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-html@email.com": {
Encrypt: false,
SignatureType: DetachedSignature,
EncryptionScheme: ClearMIMEScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "PGP MIME with detached signature",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-encrypted@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: PGPMIMEScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: false,
},
{
name: "PGP MIME with plain text (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-encrypted-plain@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: PGPMIMEScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: true,
},
{
name: "PGP MIME with rich text (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-encrypted-plain@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: PGPMIMEScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "PGP MIME with missing public key (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-encrypted-no-pubkey@email.com": {
Encrypt: true,
SignatureType: DetachedSignature,
EncryptionScheme: PGPMIMEScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: true,
},
{
name: "PGP MIME with no signature (error)",
mimeBody: "this is a mime body",
prefs: map[string]SendPreferences{"mime-encrypted-no-signature@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: NoSignature,
EncryptionScheme: PGPMIMEScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req SendDraftReq
if err := req.AddMIMEPackage(kr, tt.mimeBody, tt.prefs); (err != nil) != tt.wantErr {
t.Errorf("SendDraftReq.AddMIMEPackage() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestSendDraftReq_AddPackage(t *testing.T) {
key, err := crypto.GenerateKey("name", "email", "rsa", 2048)
require.NoError(t, err)
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
tests := []struct {
name string
body string
mimeType rfc822.MIMEType
prefs map[string]SendPreferences
attKeys map[string]*crypto.SessionKey
wantErr bool
}{
{
name: "internal plain text with detached signature",
body: "this is a text/plain body",
mimeType: rfc822.TextPlain,
prefs: map[string]SendPreferences{"internal-plain@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: false,
},
{
name: "internal rich text with detached signature",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"internal-html@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: false,
},
{
name: "internal rich text with bad package content type (error)",
body: "this is a text/html body",
mimeType: "bad content type",
prefs: map[string]SendPreferences{"internal-bad-package-content-type@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "internal rich text with bad recipient content type (error)",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"internal-bad-recipient-content-type@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: "bad content type",
}},
wantErr: true,
},
{
name: "internal with multipart (error)",
body: "this is a text/html body",
mimeType: rfc822.MultipartMixed,
prefs: map[string]SendPreferences{"internal-multipart-mixed@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: true,
},
{
name: "internal without encryption (error)",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"internal-no-encrypt@email.com": {
Encrypt: false,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "internal without pubkey (error)",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"internal-no-pubkey@email.com": {
Encrypt: true,
SignatureType: DetachedSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "internal without signature (error)",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"internal-no-sig@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: NoSignature,
EncryptionScheme: InternalScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "clear rich text without signature",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"clear-rich@email.com": {
Encrypt: false,
SignatureType: NoSignature,
EncryptionScheme: ClearScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: false,
},
{
name: "clear plain text without signature",
body: "this is a text/plain body",
mimeType: rfc822.TextPlain,
prefs: map[string]SendPreferences{"clear-plain@email.com": {
Encrypt: false,
SignatureType: NoSignature,
EncryptionScheme: ClearScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: false,
},
{
name: "clear plain text with signature",
body: "this is a text/plain body",
mimeType: rfc822.TextPlain,
prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": {
Encrypt: false,
SignatureType: DetachedSignature,
EncryptionScheme: ClearScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: false,
},
{
name: "clear plain text with bad scheme (error)",
body: "this is a text/plain body",
mimeType: rfc822.TextPlain,
prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": {
Encrypt: false,
SignatureType: DetachedSignature,
EncryptionScheme: PGPInlineScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: true,
},
{
name: "clear rich text with signature (error)",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": {
Encrypt: false,
SignatureType: DetachedSignature,
EncryptionScheme: ClearScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "encrypted plain text with signature",
body: "this is a text/plain body",
mimeType: rfc822.TextPlain,
prefs: map[string]SendPreferences{"pgp-inline-with-sig@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: PGPInlineScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: false,
},
{
name: "encrypted html text with signature (error)",
body: "this is a text/html body",
mimeType: rfc822.TextHTML,
prefs: map[string]SendPreferences{"pgp-inline-rich-with-sig@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: PGPInlineScheme,
MIMEType: rfc822.TextHTML,
}},
wantErr: true,
},
{
name: "encrypted mixed text with signature (error)",
body: "this is a multipart/mixed body",
mimeType: rfc822.MultipartMixed,
prefs: map[string]SendPreferences{"pgp-inline-mixed-with-sig@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: PGPInlineScheme,
MIMEType: rfc822.MultipartMixed,
}},
wantErr: true,
},
{
name: "encrypted for outside (error)",
body: "this is a text/plain body",
mimeType: rfc822.TextPlain,
prefs: map[string]SendPreferences{"enc-for-outside@email.com": {
Encrypt: true,
PubKey: kr,
SignatureType: DetachedSignature,
EncryptionScheme: EncryptedOutsideScheme,
MIMEType: rfc822.TextPlain,
}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req SendDraftReq
if err := req.AddTextPackage(kr, tt.body, tt.mimeType, tt.prefs, tt.attKeys); (err != nil) != tt.wantErr {
t.Errorf("SendDraftReq.AddPackage() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

192
message_types.go Normal file
View File

@@ -0,0 +1,192 @@
package proton
import (
"net/mail"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"golang.org/x/exp/slices"
)
type MessageMetadata struct {
ID string
AddressID string
LabelIDs []string
ExternalID string
Subject string
Sender *mail.Address
ToList []*mail.Address
CCList []*mail.Address
BCCList []*mail.Address
ReplyTos []*mail.Address
Flags MessageFlag
Time int64
Size int
Unread Bool
IsReplied Bool
IsRepliedAll Bool
IsForwarded Bool
}
func (meta MessageMetadata) Seen() bool {
return !bool(meta.Unread)
}
func (meta MessageMetadata) Starred() bool {
return slices.Contains(meta.LabelIDs, StarredLabel)
}
func (meta MessageMetadata) IsDraft() bool {
return meta.Flags&(MessageFlagReceived|MessageFlagSent) == 0
}
type MessageFilter struct {
ID []string `json:",omitempty"`
AddressID string `json:",omitempty"`
ExternalID string `json:",omitempty"`
LabelID string `json:",omitempty"`
}
type Message struct {
MessageMetadata
Header string
ParsedHeaders Headers
Body string
MIMEType rfc822.MIMEType
Attachments []Attachment
}
type MessageFlag int64
const (
MessageFlagReceived MessageFlag = 1 << iota
MessageFlagSent
MessageFlagInternal
MessageFlagE2E
MessageFlagAuto
MessageFlagReplied
MessageFlagRepliedAll
MessageFlagForwarded
MessageFlagAutoReplied
MessageFlagImported
MessageFlagOpened
MessageFlagReceiptSent
MessageFlagNotified
MessageFlagTouched
MessageFlagReceipt
MessageFlagProton
MessageFlagReceiptRequest
MessageFlagPublicKey
MessageFlagSign
MessageFlagUnsubscribed
MessageFlagSPFFail
MessageFlagDKIMFail
MessageFlagDMARCFail
MessageFlagHamManual
MessageFlagSpamAuto
MessageFlagSpamManual
MessageFlagPhishingAuto
MessageFlagPhishingManual
)
func (f MessageFlag) Has(flag MessageFlag) bool {
return f&flag != 0
}
func (f MessageFlag) Matches(flag MessageFlag) bool {
return f&flag == flag
}
func (f MessageFlag) HasAny(flags ...MessageFlag) bool {
for _, flag := range flags {
if f.Has(flag) {
return true
}
}
return false
}
func (f MessageFlag) HasAll(flags ...MessageFlag) bool {
for _, flag := range flags {
if !f.Has(flag) {
return false
}
}
return true
}
func (f MessageFlag) Add(flag MessageFlag) MessageFlag {
return f | flag
}
func (f MessageFlag) Remove(flag MessageFlag) MessageFlag {
return f &^ flag
}
func (f MessageFlag) Toggle(flag MessageFlag) MessageFlag {
if f.Has(flag) {
return f.Remove(flag)
}
return f.Add(flag)
}
func (m Message) Decrypt(kr *crypto.KeyRing) ([]byte, error) {
enc, err := crypto.NewPGPMessageFromArmored(m.Body)
if err != nil {
return nil, err
}
dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime())
if err != nil {
return nil, err
}
return dec.GetBinary(), nil
}
type FullMessage struct {
Message
AttData [][]byte
}
type Signature struct {
Hash string
Data *crypto.PGPSignature
}
type MessageActionReq struct {
IDs []string
}
type LabelMessagesReq struct {
LabelID string
IDs []string
}
type LabelMessagesRes struct {
Responses []LabelMessageRes
UndoToken UndoToken
}
func (res LabelMessagesRes) ok() bool {
for _, resp := range res.Responses {
if resp.Response.Code != SuccessCode {
return false
}
}
return true
}
type LabelMessageRes struct {
ID string
Response Error
}

49
message_types_test.go Normal file
View File

@@ -0,0 +1,49 @@
package proton
import (
"os"
"testing"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/stretchr/testify/require"
)
func TestDecrypt(t *testing.T) {
body, err := os.ReadFile("testdata/body.pgp")
require.NoError(t, err)
pubKR := loadKeyRing(t, "testdata/pub.asc", nil)
prvKR := loadKeyRing(t, "testdata/prv.asc", []byte("password"))
msg := Message{Body: string(body)}
sigs, err := ExtractSignatures(prvKR, msg.Body)
require.NoError(t, err)
enc, err := crypto.NewPGPMessageFromArmored(msg.Body)
require.NoError(t, err)
dec, err := prvKR.Decrypt(enc, nil, crypto.GetUnixTime())
require.NoError(t, err)
require.NoError(t, pubKR.VerifyDetached(dec, sigs[0].Data, crypto.GetUnixTime()))
}
func loadKeyRing(t *testing.T, file string, pass []byte) *crypto.KeyRing {
f, err := os.Open(file)
require.NoError(t, err)
defer f.Close()
key, err := crypto.NewKeyFromArmoredReader(f)
require.NoError(t, err)
if pass != nil {
key, err = key.Unlock(pass)
require.NoError(t, err)
}
kr, err := crypto.NewKeyRing(key)
require.NoError(t, err)
return kr
}

138
option.go Normal file
View File

@@ -0,0 +1,138 @@
package proton
import (
"net/http"
"github.com/go-resty/resty/v2"
)
// Option represents a type that can be used to configure the manager.
type Option interface {
config(*managerBuilder)
}
func WithHostURL(hostURL string) Option {
return &withHostURL{
hostURL: hostURL,
}
}
type withHostURL struct {
hostURL string
}
func (opt withHostURL) config(builder *managerBuilder) {
builder.hostURL = opt.hostURL
}
func WithAppVersion(appVersion string) Option {
return &withAppVersion{
appVersion: appVersion,
}
}
type withAppVersion struct {
appVersion string
}
func (opt withAppVersion) config(builder *managerBuilder) {
builder.appVersion = opt.appVersion
}
func WithTransport(transport http.RoundTripper) Option {
return &withTransport{
transport: transport,
}
}
type withTransport struct {
transport http.RoundTripper
}
func (opt withTransport) config(builder *managerBuilder) {
builder.transport = opt.transport
}
type withAttPoolSize struct {
attPoolSize int
}
func (opt withAttPoolSize) config(builder *managerBuilder) {
builder.attPoolSize = opt.attPoolSize
}
func WithAttPoolSize(attPoolSize int) Option {
return &withAttPoolSize{
attPoolSize: attPoolSize,
}
}
type withSkipVerifyProofs struct {
skipVerifyProofs bool
}
func (opt withSkipVerifyProofs) config(builder *managerBuilder) {
builder.verifyProofs = !opt.skipVerifyProofs
}
func WithSkipVerifyProofs() Option {
return &withSkipVerifyProofs{
skipVerifyProofs: true,
}
}
func WithRetryCount(retryCount int) Option {
return &withRetryCount{
retryCount: retryCount,
}
}
type withRetryCount struct {
retryCount int
}
func (opt withRetryCount) config(builder *managerBuilder) {
builder.retryCount = opt.retryCount
}
func WithCookieJar(jar http.CookieJar) Option {
return &withCookieJar{
jar: jar,
}
}
type withCookieJar struct {
jar http.CookieJar
}
func (opt withCookieJar) config(builder *managerBuilder) {
builder.cookieJar = opt.jar
}
func WithLogger(logger resty.Logger) Option {
return &withLogger{
logger: logger,
}
}
type withLogger struct {
logger resty.Logger
}
func (opt withLogger) config(builder *managerBuilder) {
builder.logger = opt.logger
}
func WithDebug(debug bool) Option {
return &withDebug{
debug: debug,
}
}
type withDebug struct {
debug bool
}
func (opt withDebug) config(builder *managerBuilder) {
builder.debug = opt.debug
}

2
package.go Normal file
View File

@@ -0,0 +1,2 @@
// Package proton implements types for accessing the Proton API.
package proton

33
paging.go Normal file
View File

@@ -0,0 +1,33 @@
package proton
import (
"context"
"runtime"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/stream"
)
const maxPageSize = 150
func fetchPaged[T any](
ctx context.Context,
total, pageSize int,
fn func(ctx context.Context, page, pageSize int) ([]T, error),
) ([]T, error) {
return stream.Collect(ctx, stream.Flatten(parallel.MapStream(
ctx,
stream.FromIterator(iterator.Counter(total/pageSize+1)),
runtime.NumCPU(),
runtime.NumCPU(),
func(ctx context.Context, page int) (stream.Stream[T], error) {
values, err := fn(ctx, page, pageSize)
if err != nil {
return nil, err
}
return stream.FromIterator(iterator.Slice(values)), nil
},
)))
}

166
pool.go Normal file
View File

@@ -0,0 +1,166 @@
package proton
import (
"context"
"errors"
"fmt"
"sync"
"github.com/ProtonMail/gluon/queue"
)
// ErrJobCancelled indicates the job was cancelled.
var ErrJobCancelled = errors.New("job cancelled by surrounding context")
// Pool is a worker pool that handles input of type In and returns results of type Out.
type Pool[In comparable, Out any] struct {
queue *queue.QueuedChannel[*job[In, Out]]
wg sync.WaitGroup
}
// doneFunc must be called to free up pool resources.
type doneFunc func()
// New returns a new pool.
func NewPool[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
pool := &Pool[In, Out]{
queue: queue.NewQueuedChannel[*job[In, Out]](0, 0),
}
for i := 0; i < size; i++ {
pool.wg.Add(1)
go func() {
defer pool.wg.Done()
for job := range pool.queue.GetChannel() {
select {
case <-job.ctx.Done():
job.postFailure(ErrJobCancelled)
default:
res, err := work(job.ctx, job.req)
if err != nil {
job.postFailure(err)
} else {
job.postSuccess(res)
}
job.waitDone()
}
}
}()
}
return pool
}
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(int, In, Out, error) error) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var (
wg sync.WaitGroup
errList []error
lock sync.Mutex
)
for i, req := range reqs {
req := req
wg.Add(1)
go func(index int) {
defer wg.Done()
job, done, err := pool.newJob(ctx, req)
if err != nil {
lock.Lock()
defer lock.Unlock()
// Cancel ongoing jobs.
cancel()
// Collect the error.
errList = append(errList, err)
return
}
defer done()
res, err := job.result()
if err := fn(index, req, res, err); err != nil {
lock.Lock()
defer lock.Unlock()
// Cancel ongoing jobs.
cancel()
// Collect the error.
errList = append(errList, err)
}
}(i)
}
wg.Wait()
// TODO: Join the errors somehow?
if len(errList) > 0 {
return errList[0]
}
return nil
}
// ProcessAll submits jobs to the pool. All results are returned once available.
func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) ([]Out, error) {
data := make([]Out, len(reqs))
if err := pool.Process(ctx, reqs, func(index int, req In, res Out, err error) error {
if err != nil {
return err
}
data[index] = res
return nil
}); err != nil {
return nil, err
}
return data, nil
}
// ProcessOne submits one job to the pool and returns the result.
func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) {
job, done, err := pool.newJob(ctx, req)
if err != nil {
var o Out
return o, err
}
defer done()
return job.result()
}
func (pool *Pool[In, Out]) Done() {
pool.queue.Close()
pool.wg.Wait()
}
// newJob submits a job to the pool. It returns a job handle and a DoneFunc.
// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done,
// which frees up the worker in the pool for reuse.
func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc, error) {
job := newJob[In, Out](ctx, req)
if !pool.queue.Enqueue(job) {
return nil, nil, fmt.Errorf("pool closed")
}
return job, func() { close(job.done) }, nil
}

173
pool_test.go Normal file
View File

@@ -0,0 +1,173 @@
package proton
import (
"context"
"errors"
"runtime"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPool_NewJob(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
defer doubler.Done()
job1, done1, err := doubler.newJob(context.Background(), 1)
require.NoError(t, err)
defer done1()
job2, done2, err := doubler.newJob(context.Background(), 2)
require.NoError(t, err)
defer done2()
res2, err := job2.result()
require.NoError(t, err)
res1, err := job1.result()
require.NoError(t, err)
assert.Equal(t, 2, res1)
assert.Equal(t, 4, res2)
}
func TestPool_NewJob_Done(t *testing.T) {
// Create a doubler pool with 2 workers.
doubler := newDoubler(2)
defer doubler.Done()
// Start two jobs. Don't mark the jobs as done yet.
job1, done1, err := doubler.newJob(context.Background(), 1)
require.NoError(t, err)
job2, done2, err := doubler.newJob(context.Background(), 2)
require.NoError(t, err)
// Get the first result.
res1, _ := job1.result()
assert.Equal(t, 2, res1)
// Get the first result.
res2, _ := job2.result()
assert.Equal(t, 4, res2)
// Additional jobs will wait.
job3, done3, err := doubler.newJob(context.Background(), 3)
require.NoError(t, err)
job4, done4, err := doubler.newJob(context.Background(), 4)
require.NoError(t, err)
// Channel to collect results from jobs 3 and 4.
resCh := make(chan int, 2)
go func() {
res, _ := job3.result()
resCh <- res
}()
go func() {
res, _ := job4.result()
resCh <- res
}()
// Mark jobs 1 and 2 as done, freeing up the workers.
done1()
done2()
assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh})
// Mark jobs 3 and 4 as done, freeing up the workers.
done3()
done4()
}
func TestPool_Process(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
defer doubler.Done()
res := make([]int, 5)
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(index, reqVal, resVal int, err error) error {
require.NoError(t, err)
res[index] = resVal
return nil
}))
assert.Equal(t, []int{
2,
4,
6,
8,
10,
}, res)
}
func TestPool_Process_Error(t *testing.T) {
doubler := newDoublerWithError(runtime.NumCPU())
defer doubler.Done()
assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_int, _ int, _ int, err error) error {
return err
}))
}
func TestPool_Process_Parallel(t *testing.T) {
doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond)
defer doubler.Done()
var wg sync.WaitGroup
for i := 0; i < 8; i++ {
wg.Add(1)
go func() {
defer wg.Done()
require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, _ int, err error) error {
return nil
}))
}()
}
wg.Wait()
}
func TestPool_ProcessAll(t *testing.T) {
doubler := newDoubler(runtime.NumCPU())
defer doubler.Done()
res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5})
require.NoError(t, err)
assert.Equal(t, []int{
2,
4,
6,
8,
10,
}, res)
}
func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] {
return NewPool(workers, func(ctx context.Context, req int) (int, error) {
if len(delay) > 0 {
time.Sleep(delay[0])
}
return 2 * req, nil
})
}
func newDoublerWithError(workers int) *Pool[int, int] {
return NewPool(workers, func(ctx context.Context, req int) (int, error) {
if req%2 == 0 {
return 0, errors.New("oops")
}
return 2 * req, nil
})
}

84
response.go Normal file
View File

@@ -0,0 +1,84 @@
package proton
import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/go-resty/resty/v2"
)
type Code int
const (
SuccessCode Code = 1000
MultiCode Code = 1001
InvalidValue Code = 2001
AppVersionMissingCode Code = 5001
AppVersionBadCode Code = 5003
PasswordWrong Code = 8002
HumanVerificationRequired Code = 9001
PaidPlanRequired Code = 10004
AuthRefreshTokenInvalid Code = 10013
)
type Error struct {
Code Code
Message string `json:"Error"`
}
func (err Error) Error() string {
return err.Message
}
func catchAPIError(_ *resty.Client, res *resty.Response) error {
if !res.IsError() {
return nil
}
var err error
if apiErr, ok := res.Error().(*Error); ok {
err = apiErr
} else {
err = fmt.Errorf("%v", res.Status())
}
return fmt.Errorf("%v: %w", res.StatusCode(), err)
}
func updateTime(_ *resty.Client, res *resty.Response) error {
date, err := time.Parse(time.RFC1123, res.Header().Get("Date"))
if err != nil {
return err
}
crypto.UpdateTime(date.Unix())
return nil
}
func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) {
if res.StatusCode() == http.StatusTooManyRequests {
if after := res.Header().Get("Retry-After"); after != "" {
seconds, err := strconv.Atoi(after)
if err != nil {
return 0, err
}
return time.Duration(seconds) * time.Second, nil
}
}
return 0, nil
}
func catchTooManyRequests(res *resty.Response, _ error) bool {
return res.StatusCode() == http.StatusTooManyRequests
}
func catchDialError(res *resty.Response, err error) bool {
return res.RawResponse == nil
}

21
salt.go Normal file
View File

@@ -0,0 +1,21 @@
package proton
import (
"context"
"github.com/go-resty/resty/v2"
)
func (c *Client) GetSalts(ctx context.Context) (Salts, error) {
var res struct {
KeySalts []Salt
}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetResult(&res).Get("/core/v4/keys/salts")
}); err != nil {
return nil, err
}
return res.KeySalts, nil
}

37
salt_types.go Normal file
View File

@@ -0,0 +1,37 @@
package proton
import (
"encoding/base64"
"fmt"
"github.com/ProtonMail/go-srp"
"github.com/bradenaw/juniper/xslices"
)
type Salt struct {
ID, KeySalt string
}
type Salts []Salt
func (salts Salts) SaltForKey(keyPass []byte, keyID string) ([]byte, error) {
idx := xslices.IndexFunc(salts, func(salt Salt) bool {
return salt.ID == keyID
})
if idx < 0 {
return nil, fmt.Errorf("no salt found for key %s", keyID)
}
keySalt, err := base64.StdEncoding.DecodeString(salts[idx].KeySalt)
if err != nil {
return nil, err
}
saltedKeyPass, err := srp.MailboxPassword(keyPass, keySalt)
if err != nil {
return nil, nil
}
return saltedKeyPass[len(saltedKeyPass)-31:], nil
}

74
server/addresses.go Normal file
View File

@@ -0,0 +1,74 @@
package server
import (
"net/http"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xslices"
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
)
func (s *Server) handleGetAddresses() gin.HandlerFunc {
return func(c *gin.Context) {
addresses, err := s.b.GetAddresses(c.GetString("UserID"))
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{
"Addresses": addresses,
})
}
}
func (s *Server) handleGetAddress() gin.HandlerFunc {
return func(c *gin.Context) {
addresses, err := s.b.GetAddresses(c.GetString("UserID"))
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{
"Address": addresses[xslices.IndexFunc(addresses, func(address proton.Address) bool {
return address.ID == c.Param("addressID")
})],
})
}
}
func (s *Server) handlePutAddressesOrder() gin.HandlerFunc {
return func(c *gin.Context) {
var req proton.OrderAddressesReq
if err := c.BindJSON(&req); err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
addresses, err := s.b.GetAddresses(c.GetString("UserID"))
if err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
if len(req.AddressIDs) != len(addresses) {
c.AbortWithStatus(http.StatusBadRequest)
return
}
for _, address := range addresses {
if !slices.Contains(req.AddressIDs, address.ID) {
c.AbortWithStatus(http.StatusBadRequest)
return
}
}
if err := s.b.SetAddressOrder(c.GetString("UserID"), req.AddressIDs); err != nil {
c.AbortWithStatus(http.StatusInternalServerError)
return
}
}
}

66
server/attachments.go Normal file
View File

@@ -0,0 +1,66 @@
package server
import (
"io"
"mime/multipart"
"net/http"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/gin-gonic/gin"
)
func (s *Server) handlePostMailAttachments() gin.HandlerFunc {
return func(c *gin.Context) {
form, err := c.MultipartForm()
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
attachment, err := s.b.CreateAttachment(
c.GetString("UserID"),
form.Value["MessageID"][0],
form.Value["Filename"][0],
rfc822.MIMEType(form.Value["MIMEType"][0]),
proton.Disposition(form.Value["Disposition"][0]),
mustReadFileHeader(form.File["KeyPackets"][0]),
mustReadFileHeader(form.File["DataPacket"][0]),
string(mustReadFileHeader(form.File["Signature"][0])),
)
if err != nil {
_ = c.AbortWithError(http.StatusUnprocessableEntity, err)
return
}
c.JSON(http.StatusOK, gin.H{
"Attachment": attachment,
})
}
}
func (s *Server) handleGetMailAttachment() gin.HandlerFunc {
return func(c *gin.Context) {
attData, err := s.b.GetAttachment(c.Param("attachID"))
if err != nil {
_ = c.AbortWithError(http.StatusUnprocessableEntity, err)
return
}
c.Data(http.StatusOK, "application/octet-stream", attData)
}
}
func mustReadFileHeader(fh *multipart.FileHeader) []byte {
f, err := fh.Open()
if err != nil {
panic(err)
}
data, err := io.ReadAll(f)
if err != nil {
panic(err)
}
return data
}

124
server/auth.go Normal file
View File

@@ -0,0 +1,124 @@
package server
import (
"encoding/base64"
"net/http"
"github.com/ProtonMail/go-proton-api"
"github.com/gin-gonic/gin"
)
func (s *Server) handlePostAuthInfo() gin.HandlerFunc {
return func(c *gin.Context) {
var req proton.AuthInfoReq
if err := c.BindJSON(&req); err != nil {
return
}
info, err := s.b.NewAuthInfo(req.Username)
if err != nil {
_ = c.AbortWithError(http.StatusUnauthorized, err)
return
}
c.JSON(http.StatusOK, info)
}
}
func (s *Server) handlePostAuth() gin.HandlerFunc {
return func(c *gin.Context) {
var req proton.AuthReq
if err := c.BindJSON(&req); err != nil {
return
}
clientEphemeral, err := base64.StdEncoding.DecodeString(req.ClientEphemeral)
if err != nil {
_ = c.AbortWithError(http.StatusBadRequest, err)
return
}
clientProof, err := base64.StdEncoding.DecodeString(req.ClientProof)
if err != nil {
_ = c.AbortWithError(http.StatusBadRequest, err)
return
}
auth, err := s.b.NewAuth(req.Username, clientEphemeral, clientProof, req.SRPSession)
if err != nil {
_ = c.AbortWithError(http.StatusUnauthorized, err)
return
}
c.JSON(http.StatusOK, auth)
}
}
func (s *Server) handlePostAuthRefresh() gin.HandlerFunc {
return func(c *gin.Context) {
var req proton.AuthRefreshReq
if err := c.BindJSON(&req); err != nil {
return
}
auth, err := s.b.NewAuthRef(req.UID, req.RefreshToken)
if err != nil {
_ = c.AbortWithError(http.StatusUnauthorized, err)
return
}
c.JSON(http.StatusOK, auth)
}
}
func (s *Server) handleDeleteAuth() gin.HandlerFunc {
return func(c *gin.Context) {
if err := s.b.DeleteSession(c.GetString("UserID"), c.GetString("AuthUID")); err != nil {
_ = c.AbortWithError(http.StatusUnauthorized, err)
return
}
}
}
func (s *Server) handleGetAuthSessions() gin.HandlerFunc {
return func(c *gin.Context) {
sessions, err := s.b.GetSessions(c.GetString("UserID"))
if err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
c.JSON(http.StatusOK, gin.H{"Sessions": sessions})
}
}
func (s *Server) handleDeleteAuthSessions() gin.HandlerFunc {
return func(c *gin.Context) {
sessions, err := s.b.GetSessions(c.GetString("UserID"))
if err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
for _, session := range sessions {
if session.UID != c.GetString("AuthUID") {
if err := s.b.DeleteSession(c.GetString("UserID"), session.UID); err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
}
}
}
}
func (s *Server) handleDeleteAuthSession() gin.HandlerFunc {
return func(c *gin.Context) {
if err := s.b.DeleteSession(c.GetString("UserID"), c.Param("authUID")); err != nil {
_ = c.AbortWithError(http.StatusInternalServerError, err)
return
}
}
}

87
server/backend/account.go Normal file
View File

@@ -0,0 +1,87 @@
package backend
import (
"sync"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
)
type account struct {
userID string
username string
addresses map[string]*address
auth map[string]auth
authLock sync.RWMutex
keys []key
salt []byte
verifier []byte
labelIDs []string
messageIDs []string
updateIDs []ID
}
func newAccount(userID, username string, armKey string, salt, verifier []byte) *account {
return &account{
userID: userID,
username: username,
addresses: make(map[string]*address),
auth: make(map[string]auth),
keys: []key{{keyID: uuid.NewString(), key: armKey}},
salt: salt,
verifier: verifier,
}
}
func (acc *account) toUser() proton.User {
return proton.User{
ID: acc.userID,
Name: acc.username,
DisplayName: acc.username,
Email: acc.primary().email,
Keys: xslices.Map(acc.keys, func(key key) proton.Key {
privKey, err := crypto.NewKeyFromArmored(key.key)
if err != nil {
panic(err)
}
rawKey, err := privKey.Serialize()
if err != nil {
panic(err)
}
return proton.Key{
ID: key.keyID,
PrivateKey: rawKey,
Primary: key == acc.keys[0],
Active: true,
}
}),
}
}
func (acc *account) primary() *address {
for _, addr := range acc.addresses {
if addr.order == 1 {
return addr
}
}
panic("no primary address")
}
func (acc *account) getAddr(email string) (*address, bool) {
for _, addr := range acc.addresses {
if addr.email == email {
return addr, true
}
}
return nil, false
}

49
server/backend/address.go Normal file
View File

@@ -0,0 +1,49 @@
package backend
import (
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
)
type address struct {
addrID string
email string
order int
keys []key
}
func (add *address) toAddress() proton.Address {
return proton.Address{
ID: add.addrID,
Email: add.email,
Send: true,
Receive: true,
Status: proton.AddressStatusEnabled,
Order: add.order,
DisplayName: add.email,
Keys: xslices.Map(add.keys, func(key key) proton.Key {
privKey, err := crypto.NewKeyFromArmored(key.key)
if err != nil {
panic(err)
}
rawKey, err := privKey.Serialize()
if err != nil {
panic(err)
}
return proton.Key{
ID: key.keyID,
PrivateKey: rawKey,
Token: key.tok,
Signature: key.sig,
Primary: key == add.keys[0],
Active: true,
}
}),
}
}

772
server/backend/api.go Normal file
View File

@@ -0,0 +1,772 @@
package backend
import (
"encoding/base64"
"errors"
"fmt"
"net/mail"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
func (b *Backend) GetUser(userID string) (proton.User, error) {
return withAcc(b, userID, func(acc *account) (proton.User, error) {
return acc.toUser(), nil
})
}
func (b *Backend) GetKeySalts(userID string) ([]proton.Salt, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Salt, error) {
return xslices.Map(acc.keys, func(key key) proton.Salt {
return proton.Salt{
ID: key.keyID,
KeySalt: base64.StdEncoding.EncodeToString(acc.salt),
}
}), nil
})
}
func (b *Backend) GetMailSettings(userID string) (proton.MailSettings, error) {
return withAcc(b, userID, func(acc *account) (proton.MailSettings, error) {
return proton.MailSettings{
DisplayName: acc.username,
DraftMIMEType: rfc822.TextHTML,
}, nil
})
}
func (b *Backend) GetAddressID(email string) (string, error) {
return withAccEmail(b, email, func(acc *account) (string, error) {
addr, ok := acc.getAddr(email)
if !ok {
return "", fmt.Errorf("no such address: %s", email)
}
return addr.addrID, nil
})
}
func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Address, error) {
return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address {
return add.toAddress()
}), nil
})
}
func (b *Backend) SetAddressOrder(userID string, addrIDs []string) error {
return b.withAcc(userID, func(acc *account) error {
for i, addrID := range addrIDs {
if add, ok := acc.addresses[addrID]; ok {
add.order = i + 1
} else {
return fmt.Errorf("no such address: %s", addrID)
}
}
return nil
})
}
func (b *Backend) HasLabel(userID, labelName string) (string, bool, error) {
labels, err := b.GetLabels(userID)
if err != nil {
return "", false, err
}
for _, label := range labels {
if label.Name == labelName {
return label.ID, true, nil
}
}
return "", false, nil
}
func (b *Backend) GetLabel(userID, labelID string) (proton.Label, error) {
labels, err := b.GetLabels(userID)
if err != nil {
return proton.Label{}, err
}
for _, label := range labels {
if label.ID == labelID {
return label, nil
}
}
return proton.Label{}, fmt.Errorf("no such label: %s", labelID)
}
func (b *Backend) GetLabels(userID string, types ...proton.LabelType) ([]proton.Label, error) {
return withAcc(b, userID, func(acc *account) ([]proton.Label, error) {
return withLabels(b, func(labels map[string]*label) ([]proton.Label, error) {
res := xslices.Map(acc.labelIDs, func(labelID string) proton.Label {
return labels[labelID].toLabel(labels)
})
for labelName, labelID := range map[string]string{
"Inbox": proton.InboxLabel,
"AllDrafts": proton.AllDraftsLabel,
"AllSent": proton.AllSentLabel,
"Trash": proton.TrashLabel,
"Spam": proton.SpamLabel,
"All Mail": proton.AllMailLabel,
"Archive": proton.ArchiveLabel,
"Sent": proton.SentLabel,
"Drafts": proton.DraftsLabel,
"Outbox": proton.OutboxLabel,
"Starred": proton.StarredLabel,
} {
res = append(res, proton.Label{
ID: labelID,
Name: labelName,
Path: []string{labelName},
Type: proton.LabelTypeSystem,
})
}
if len(types) > 0 {
res = xslices.Filter(res, func(label proton.Label) bool {
return slices.Contains(types, label.Type)
})
}
return res, nil
})
})
}
func (b *Backend) CreateLabel(userID, labelName, parentID string, labelType proton.LabelType) (proton.Label, error) {
return withAcc(b, userID, func(acc *account) (proton.Label, error) {
return withLabels(b, func(labels map[string]*label) (proton.Label, error) {
if parentID != "" {
if labelType != proton.LabelTypeFolder {
return proton.Label{}, fmt.Errorf("parentID can only be set for folders")
}
if _, ok := labels[parentID]; !ok {
return proton.Label{}, fmt.Errorf("no such parent label: %s", parentID)
}
}
label := newLabel(labelName, parentID, labelType)
labels[label.labelID] = label
updateID, err := b.newUpdate(&labelCreated{labelID: label.labelID})
if err != nil {
return proton.Label{}, err
}
acc.labelIDs = append(acc.labelIDs, label.labelID)
acc.updateIDs = append(acc.updateIDs, updateID)
return label.toLabel(labels), nil
})
})
}
func (b *Backend) UpdateLabel(userID, labelID, name, parentID string) (proton.Label, error) {
return withAcc(b, userID, func(acc *account) (proton.Label, error) {
return withLabels(b, func(labels map[string]*label) (proton.Label, error) {
if parentID != "" {
if labels[labelID].labelType != proton.LabelTypeFolder {
return proton.Label{}, fmt.Errorf("parentID can only be set for folders")
}
if _, ok := labels[parentID]; !ok {
return proton.Label{}, fmt.Errorf("no such parent label: %s", parentID)
}
}
labels[labelID].name = name
labels[labelID].parentID = parentID
updateID, err := b.newUpdate(&labelUpdated{labelID: labelID})
if err != nil {
return proton.Label{}, err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return labels[labelID].toLabel(labels), nil
})
})
}
func (b *Backend) DeleteLabel(userID, labelID string) error {
return b.withAcc(userID, func(acc *account) error {
return b.withLabels(func(labels map[string]*label) error {
if _, ok := labels[labelID]; !ok {
return errors.New("label not found")
}
for _, labelID := range getLabelIDsToDelete(labelID, labels) {
delete(labels, labelID)
updateID, err := b.newUpdate(&labelDeleted{labelID: labelID})
if err != nil {
return err
}
acc.labelIDs = xslices.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID })
acc.updateIDs = append(acc.updateIDs, updateID)
}
return nil
})
})
}
func (b *Backend) CountMessages(userID string) (int, error) {
return withAcc(b, userID, func(acc *account) (int, error) {
return len(acc.messageIDs), nil
})
}
func (b *Backend) GetMessageIDs(userID string, afterID string, limit int) ([]string, error) {
return withAcc(b, userID, func(acc *account) ([]string, error) {
if len(acc.messageIDs) == 0 {
return nil, nil
}
var lo, hi int
if afterID == "" {
lo = 0
} else {
lo = slices.Index(acc.messageIDs, afterID) + 1
}
if limit == 0 {
hi = len(acc.messageIDs)
} else {
hi = lo + limit
if hi > len(acc.messageIDs) {
hi = len(acc.messageIDs)
}
}
return acc.messageIDs[lo:hi], nil
})
}
func (b *Backend) GetMessages(userID string, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error) {
return withAcc(b, userID, func(acc *account) ([]proton.MessageMetadata, error) {
return withMessages(b, func(messages map[string]*message) ([]proton.MessageMetadata, error) {
if len(acc.messageIDs) == 0 {
return nil, nil
}
metadata := xslices.Map(xslices.Chunk(acc.messageIDs, pageSize)[page], func(messageID string) proton.MessageMetadata {
return messages[messageID].toMetadata()
})
if len(filter.ID) > 0 {
metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool {
return slices.Contains(filter.ID, metadata.ID)
})
}
if len(filter.AddressID) != 0 {
metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool {
return filter.AddressID == metadata.AddressID
})
}
if len(filter.ExternalID) != 0 {
metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool {
return filter.ExternalID != metadata.ExternalID
})
}
if len(filter.LabelID) != 0 {
metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool {
return slices.Contains(metadata.LabelIDs, filter.LabelID)
})
}
return metadata, nil
})
})
}
func (b *Backend) GetMessage(userID, messageID string) (proton.Message, error) {
return withAcc(b, userID, func(acc *account) (proton.Message, error) {
return withMessages(b, func(messages map[string]*message) (proton.Message, error) {
return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) {
message, ok := messages[messageID]
if !ok {
return proton.Message{}, errors.New("no such message")
}
return message.toMessage(atts), nil
})
})
})
}
func (b *Backend) SetMessagesRead(userID string, read bool, messageIDs ...string) error {
return b.withAcc(userID, func(acc *account) error {
return b.withMessages(func(messages map[string]*message) error {
for _, messageID := range messageIDs {
messages[messageID].unread = !read
updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
}
return nil
})
})
}
func (b *Backend) LabelMessages(userID, labelID string, messageIDs ...string) error {
if labelID == proton.AllMailLabel || labelID == proton.AllDraftsLabel || labelID == proton.AllSentLabel {
return fmt.Errorf("not allowed")
}
return b.withAcc(userID, func(acc *account) error {
return b.withMessages(func(messages map[string]*message) error {
return b.withLabels(func(labels map[string]*label) error {
for _, messageID := range messageIDs {
message, ok := messages[messageID]
if !ok {
continue
}
message.addLabel(labelID, labels)
updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
}
return nil
})
})
})
}
func (b *Backend) UnlabelMessages(userID, labelID string, messageIDs ...string) error {
if labelID == proton.AllMailLabel || labelID == proton.AllDraftsLabel || labelID == proton.AllSentLabel {
return fmt.Errorf("not allowed")
}
return b.withAcc(userID, func(acc *account) error {
return b.withMessages(func(messages map[string]*message) error {
return b.withLabels(func(labels map[string]*label) error {
for _, messageID := range messageIDs {
messages[messageID].remLabel(labelID, labels)
updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
}
return nil
})
})
})
}
func (b *Backend) DeleteMessage(userID, messageID string) error {
return b.withAcc(userID, func(acc *account) error {
return b.withMessages(func(messages map[string]*message) error {
message, ok := messages[messageID]
if !ok {
return errors.New("no such message")
}
for _, attID := range message.attIDs {
if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool {
return att.attDataID == b.attachments[attID].attDataID
}) == 1 {
delete(b.attData, b.attachments[attID].attDataID)
}
delete(b.attachments, attID)
}
delete(b.messages, messageID)
updateID, err := b.newUpdate(&messageDeleted{messageID: messageID})
if err != nil {
return err
}
acc.messageIDs = xslices.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID })
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
})
}
func (b *Backend) CreateDraft(
userID, addrID string,
subject string,
sender *mail.Address,
toList, ccList, bccList []*mail.Address,
armBody string,
mimeType rfc822.MIMEType,
externalID string,
) (proton.Message, error) {
return withAcc(b, userID, func(acc *account) (proton.Message, error) {
return withMessages(b, func(messages map[string]*message) (proton.Message, error) {
msg := newMessage(addrID, subject, sender, toList, ccList, bccList, armBody, mimeType, externalID)
messages[msg.messageID] = msg
updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID})
if err != nil {
return proton.Message{}, err
}
acc.messageIDs = append(acc.messageIDs, msg.messageID)
acc.updateIDs = append(acc.updateIDs, updateID)
return msg.toMessage(nil), nil
})
})
}
func (b *Backend) SendMessage(userID, messageID string, packages []*proton.MessagePackage) (proton.Message, error) {
return withAcc(b, userID, func(acc *account) (proton.Message, error) {
return withMessages(b, func(messages map[string]*message) (proton.Message, error) {
return withLabels(b, func(labels map[string]*label) (proton.Message, error) {
return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) {
msg := messages[messageID]
msg.flags |= proton.MessageFlagSent
msg.addLabel(proton.SentLabel, labels)
updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
if err != nil {
return proton.Message{}, err
}
acc.updateIDs = append(acc.updateIDs, updateID)
for _, pkg := range packages {
bodyData, err := base64.StdEncoding.DecodeString(pkg.Body)
if err != nil {
return proton.Message{}, err
}
for email, recipient := range pkg.Addresses {
if recipient.Type != proton.InternalScheme {
continue
}
if err := b.withAccEmail(email, func(acc *account) error {
bodyKey, err := base64.StdEncoding.DecodeString(recipient.BodyKeyPacket)
if err != nil {
return err
}
armBody, err := crypto.NewPGPSplitMessage(bodyKey, bodyData).GetPGPMessage().GetArmored()
if err != nil {
return err
}
addrID, err := b.GetAddressID(email)
if err != nil {
return err
}
newMsg := newMessage(
addrID,
msg.subject,
msg.sender,
msg.toList,
msg.ccList,
nil, // BCC is not sent to the recipient
armBody,
msg.mimeType,
msg.externalID,
)
newMsg.flags |= proton.MessageFlagReceived
newMsg.addLabel(proton.InboxLabel, labels)
newMsg.unread = true
messages[newMsg.messageID] = newMsg
for _, attID := range msg.attIDs {
attKey, err := base64.StdEncoding.DecodeString(recipient.AttachmentKeyPackets[attID])
if err != nil {
return err
}
att := newAttachment(
atts[attID].filename,
atts[attID].mimeType,
atts[attID].disposition,
attKey,
atts[attID].attDataID,
atts[attID].armSig,
)
atts[att.attachID] = att
messages[newMsg.messageID].attIDs = append(messages[newMsg.messageID].attIDs, att.attachID)
}
updateID, err := b.newUpdate(&messageCreated{messageID: newMsg.messageID})
if err != nil {
return err
}
acc.messageIDs = append(acc.messageIDs, newMsg.messageID)
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
}); err != nil {
return proton.Message{}, err
}
}
}
return msg.toMessage(atts), nil
})
})
})
})
}
func (b *Backend) CreateAttachment(
userID string,
messageID string,
filename string,
mimeType rfc822.MIMEType,
disposition proton.Disposition,
keyPackets, dataPacket []byte,
armSig string,
) (proton.Attachment, error) {
return withAcc(b, userID, func(acc *account) (proton.Attachment, error) {
return withMessages(b, func(messages map[string]*message) (proton.Attachment, error) {
return withAtts(b, func(atts map[string]*attachment) (proton.Attachment, error) {
att := newAttachment(
filename,
mimeType,
disposition,
keyPackets,
b.createAttData(dataPacket),
armSig,
)
atts[att.attachID] = att
messages[messageID].attIDs = append(messages[messageID].attIDs, att.attachID)
updateID, err := b.newUpdate(&messageUpdated{messageID: messageID})
if err != nil {
return proton.Attachment{}, err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return att.toAttachment(), nil
})
})
})
}
func (b *Backend) GetAttachment(attachID string) ([]byte, error) {
return withAtts(b, func(atts map[string]*attachment) ([]byte, error) {
att, ok := atts[attachID]
if !ok {
return nil, fmt.Errorf("no such attachment: %s", attachID)
}
return b.attData[att.attDataID], nil
})
}
func (b *Backend) GetLatestEventID(userID string) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
return acc.updateIDs[len(acc.updateIDs)-1].String(), nil
})
}
func (b *Backend) GetEvent(userID, rawEventID string) (proton.Event, error) {
var eventID ID
if err := eventID.FromString(rawEventID); err != nil {
return proton.Event{}, fmt.Errorf("invalid event ID: %s", rawEventID)
}
return withAcc(b, userID, func(acc *account) (proton.Event, error) {
return withMessages(b, func(messages map[string]*message) (proton.Event, error) {
return withLabels(b, func(labels map[string]*label) (proton.Event, error) {
updates, err := withUpdates(b, func(updates map[ID]update) ([]update, error) {
return merge(xslices.Map(acc.updateIDs[xslices.Index(acc.updateIDs, eventID)+1:], func(updateID ID) update {
return updates[updateID]
})), nil
})
if err != nil {
return proton.Event{}, fmt.Errorf("failed to merge updates: %w", err)
}
return buildEvent(updates, acc.addresses, messages, labels, acc.updateIDs[len(acc.updateIDs)-1].String()), nil
})
})
})
}
func (b *Backend) GetPublicKeys(email string) ([]proton.PublicKey, error) {
return withAccEmail(b, email, func(acc *account) ([]proton.PublicKey, error) {
var keys []proton.PublicKey
for _, addr := range acc.addresses {
if addr.email == email {
for _, key := range addr.keys {
pubKey, err := key.getPubKey()
if err != nil {
return nil, err
}
armKey, err := pubKey.GetArmoredPublicKey()
if err != nil {
return nil, err
}
keys = append(keys, proton.PublicKey{
Flags: proton.KeyStateTrusted | proton.KeyStateActive,
PublicKey: armKey,
})
}
}
}
return keys, nil
})
}
func getLabelIDsToDelete(labelID string, labels map[string]*label) []string {
labelIDs := []string{labelID}
for _, label := range labels {
if label.parentID == labelID {
labelIDs = append(labelIDs, getLabelIDsToDelete(label.labelID, labels)...)
}
}
return labelIDs
}
func buildEvent(
updates []update,
addresses map[string]*address,
messages map[string]*message,
labels map[string]*label,
eventID string,
) proton.Event {
event := proton.Event{EventID: eventID}
for _, update := range updates {
switch update := update.(type) {
case *userRefreshed:
event.Refresh = update.refresh
case *messageCreated:
event.Messages = append(event.Messages, proton.MessageEvent{
EventItem: proton.EventItem{
ID: update.messageID,
Action: proton.EventCreate,
},
Message: messages[update.messageID].toMetadata(),
})
case *messageUpdated:
event.Messages = append(event.Messages, proton.MessageEvent{
EventItem: proton.EventItem{
ID: update.messageID,
Action: proton.EventUpdate,
},
Message: messages[update.messageID].toMetadata(),
})
case *messageDeleted:
event.Messages = append(event.Messages, proton.MessageEvent{
EventItem: proton.EventItem{
ID: update.messageID,
Action: proton.EventDelete,
},
})
case *labelCreated:
event.Labels = append(event.Labels, proton.LabelEvent{
EventItem: proton.EventItem{
ID: update.labelID,
Action: proton.EventCreate,
},
Label: labels[update.labelID].toLabel(labels),
})
case *labelUpdated:
event.Labels = append(event.Labels, proton.LabelEvent{
EventItem: proton.EventItem{
ID: update.labelID,
Action: proton.EventUpdate,
},
Label: labels[update.labelID].toLabel(labels),
})
case *labelDeleted:
event.Labels = append(event.Labels, proton.LabelEvent{
EventItem: proton.EventItem{
ID: update.labelID,
Action: proton.EventDelete,
},
})
case *addressCreated:
event.Addresses = append(event.Addresses, proton.AddressEvent{
EventItem: proton.EventItem{
ID: update.addressID,
Action: proton.EventCreate,
},
Address: addresses[update.addressID].toAddress(),
})
case *addressUpdated:
event.Addresses = append(event.Addresses, proton.AddressEvent{
EventItem: proton.EventItem{
ID: update.addressID,
Action: proton.EventCreate,
},
Address: addresses[update.addressID].toAddress(),
})
case *addressDeleted:
event.Addresses = append(event.Addresses, proton.AddressEvent{
EventItem: proton.EventItem{
ID: update.addressID,
Action: proton.EventDelete,
},
})
}
}
return event
}

127
server/backend/api_auth.go Normal file
View File

@@ -0,0 +1,127 @@
package backend
import (
"encoding/base64"
"fmt"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-srp"
"github.com/google/uuid"
)
func (b *Backend) NewAuthInfo(username string) (proton.AuthInfo, error) {
return withAccName(b, username, func(acc *account) (proton.AuthInfo, error) {
server, err := srp.NewServerFromSigned(modulus, acc.verifier, 2048)
if err != nil {
return proton.AuthInfo{}, nil
}
challenge, err := server.GenerateChallenge()
if err != nil {
return proton.AuthInfo{}, nil
}
session := uuid.NewString()
b.srpLock.Lock()
defer b.srpLock.Unlock()
b.srp[session] = server
return proton.AuthInfo{
Version: 4,
Modulus: modulus,
ServerEphemeral: base64.StdEncoding.EncodeToString(challenge),
Salt: base64.StdEncoding.EncodeToString(acc.salt),
SRPSession: session,
TwoFA: proton.TwoFAInfo{Enabled: proton.TwoFADisabled},
}, nil
})
}
func (b *Backend) NewAuth(username string, ephemeral, proof []byte, session string) (proton.Auth, error) {
return withAccName(b, username, func(acc *account) (proton.Auth, error) {
b.srpLock.Lock()
defer b.srpLock.Unlock()
server, ok := b.srp[session]
if !ok {
return proton.Auth{}, fmt.Errorf("invalid session")
}
delete(b.srp, session)
serverProof, err := server.VerifyProofs(ephemeral, proof)
if !ok {
return proton.Auth{}, fmt.Errorf("invalid proof: %w", err)
}
authUID, auth := uuid.NewString(), newAuth(b.authLife)
acc.authLock.Lock()
defer acc.authLock.Unlock()
acc.auth[authUID] = auth
return auth.toAuth(acc.userID, authUID, serverProof), nil
})
}
func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) {
b.accLock.RLock()
defer b.accLock.RUnlock()
for _, acc := range b.accounts {
acc.authLock.Lock()
defer acc.authLock.Unlock()
auth, ok := acc.auth[authUID]
if !ok {
continue
}
if auth.ref != authRef {
return proton.Auth{}, fmt.Errorf("invalid auth ref")
}
newAuth := newAuth(b.authLife)
acc.auth[authUID] = newAuth
return newAuth.toAuth(acc.userID, authUID, nil), nil
}
return proton.Auth{}, fmt.Errorf("invalid auth")
}
func (b *Backend) VerifyAuth(authUID, authAcc string) (string, error) {
return withAccAuth(b, authUID, authAcc, func(acc *account) (string, error) {
return acc.userID, nil
})
}
func (b *Backend) GetSessions(userID string) ([]proton.AuthSession, error) {
return withAcc(b, userID, func(acc *account) ([]proton.AuthSession, error) {
acc.authLock.RLock()
defer acc.authLock.RUnlock()
var sessions []proton.AuthSession
for authUID, auth := range acc.auth {
sessions = append(sessions, auth.toAuthSession(authUID))
}
return sessions, nil
})
}
func (b *Backend) DeleteSession(userID, authUID string) error {
return b.withAcc(userID, func(acc *account) error {
acc.authLock.Lock()
defer acc.authLock.Unlock()
delete(acc.auth, authUID)
return nil
})
}

View File

@@ -0,0 +1,66 @@
package backend
import (
"encoding/base64"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/google/uuid"
)
func (b *Backend) createAttData(dataPacket []byte) string {
attDataID := uuid.NewString()
b.attDataLock.Lock()
defer b.attDataLock.Unlock()
b.attData[attDataID] = dataPacket
return attDataID
}
type attachment struct {
attachID string
attDataID string
filename string
mimeType rfc822.MIMEType
disposition proton.Disposition
keyPackets []byte
armSig string
}
func newAttachment(
filename string,
mimeType rfc822.MIMEType,
disposition proton.Disposition,
keyPackets []byte,
dataPacketID string,
armSig string,
) *attachment {
return &attachment{
attachID: uuid.NewString(),
attDataID: dataPacketID,
filename: filename,
mimeType: mimeType,
disposition: disposition,
keyPackets: keyPackets,
armSig: armSig,
}
}
func (att *attachment) toAttachment() proton.Attachment {
return proton.Attachment{
ID: att.attachID,
Name: att.filename,
MIMEType: att.mimeType,
Disposition: att.disposition,
KeyPackets: base64.StdEncoding.EncodeToString(att.keyPackets),
Signature: att.armSig,
}
}

544
server/backend/backend.go Normal file
View File

@@ -0,0 +1,544 @@
package backend
import (
"fmt"
"net/mail"
"sync"
"time"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-srp"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"golang.org/x/exp/maps"
)
type Backend struct {
accounts map[string]*account
accLock sync.RWMutex
attachments map[string]*attachment
attLock sync.Mutex
attData map[string][]byte
attDataLock sync.Mutex
messages map[string]*message
msgLock sync.Mutex
labels map[string]*label
lblLock sync.Mutex
updates map[ID]update
updatesLock sync.RWMutex
srp map[string]*srp.Server
srpLock sync.Mutex
authLife time.Duration
}
func New(authLife time.Duration) *Backend {
return &Backend{
accounts: make(map[string]*account),
attachments: make(map[string]*attachment),
attData: make(map[string][]byte),
messages: make(map[string]*message),
labels: make(map[string]*label),
updates: make(map[ID]update),
srp: make(map[string]*srp.Server),
authLife: authLife,
}
}
func (b *Backend) SetAuthLife(authLife time.Duration) {
b.authLife = authLife
}
func (b *Backend) CreateUser(username string, password []byte) (string, error) {
b.accLock.Lock()
defer b.accLock.Unlock()
salt, err := crypto.RandomToken(16)
if err != nil {
return "", err
}
passphrase, err := hashPassword(password, salt)
if err != nil {
return "", err
}
srpAuth, err := srp.NewAuthForVerifier(password, modulus, salt)
if err != nil {
return "", err
}
verifier, err := srpAuth.GenerateVerifier(2048)
if err != nil {
return "", err
}
armKey, err := GenerateKey(username, username, passphrase, "rsa", 2048)
if err != nil {
return "", err
}
userID := uuid.NewString()
b.accounts[userID] = newAccount(userID, username, armKey, salt, verifier)
return userID, nil
}
func (b *Backend) RemoveUser(userID string) error {
b.accLock.Lock()
defer b.accLock.Unlock()
user, ok := b.accounts[userID]
if !ok {
return fmt.Errorf("user %s does not exist", userID)
}
for _, labelID := range user.labelIDs {
delete(b.labels, labelID)
}
for _, messageID := range user.messageIDs {
for _, attID := range b.messages[messageID].attIDs {
if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool {
return att.attDataID == b.attachments[attID].attDataID
}) == 1 {
delete(b.attData, b.attachments[attID].attDataID)
}
delete(b.attachments, attID)
}
delete(b.messages, messageID)
}
delete(b.accounts, userID)
return nil
}
func (b *Backend) RefreshUser(userID string, refresh proton.RefreshFlag) error {
return b.withAcc(userID, func(acc *account) error {
updateID, err := b.newUpdate(&userRefreshed{refresh: refresh})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) CreateUserKey(userID string, password []byte) error {
b.accLock.Lock()
defer b.accLock.Unlock()
user, ok := b.accounts[userID]
if !ok {
return fmt.Errorf("user %s does not exist", userID)
}
salt, err := crypto.RandomToken(16)
if err != nil {
return err
}
passphrase, err := hashPassword(password, salt)
if err != nil {
return err
}
armKey, err := GenerateKey(user.username, user.username, passphrase, "rsa", 2048)
if err != nil {
return err
}
user.keys = append(user.keys, key{keyID: uuid.NewString(), key: armKey})
return nil
}
func (b *Backend) RemoveUserKey(userID, keyID string) error {
b.accLock.Lock()
defer b.accLock.Unlock()
user, ok := b.accounts[userID]
if !ok {
return fmt.Errorf("user %s does not exist", userID)
}
idx := xslices.IndexFunc(user.keys, func(key key) bool {
return key.keyID == keyID
})
if idx == -1 {
return fmt.Errorf("key %s does not exist", keyID)
}
user.keys = append(user.keys[:idx], user.keys[idx+1:]...)
return nil
}
func (b *Backend) CreateAddress(userID, email string, password []byte) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
token, err := crypto.RandomToken(32)
if err != nil {
return "", err
}
armKey, err := GenerateKey(acc.username, email, token, "rsa", 2048)
if err != nil {
return "", err
}
passphrase, err := hashPassword([]byte(password), acc.salt)
if err != nil {
return "", err
}
userKR, err := acc.keys[0].unlock(passphrase)
if err != nil {
return "", err
}
encToken, sigToken, err := encryptWithSignature(userKR, token)
if err != nil {
return "", err
}
addressID := uuid.NewString()
acc.addresses[addressID] = &address{
addrID: addressID,
email: email,
order: len(acc.addresses) + 1,
keys: []key{{
keyID: uuid.NewString(),
key: armKey,
tok: encToken,
sig: sigToken,
}},
}
updateID, err := b.newUpdate(&addressCreated{addressID: addressID})
if err != nil {
return "", err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return addressID, nil
})
}
func (b *Backend) CreateAddressKey(userID, addrID string, password []byte) error {
return b.withAcc(userID, func(acc *account) error {
token, err := crypto.RandomToken(32)
if err != nil {
return err
}
armKey, err := GenerateKey(acc.username, acc.addresses[addrID].email, token, "rsa", 2048)
if err != nil {
return err
}
passphrase, err := hashPassword([]byte(password), acc.salt)
if err != nil {
return err
}
userKR, err := acc.keys[0].unlock(passphrase)
if err != nil {
return err
}
encToken, sigToken, err := encryptWithSignature(userKR, token)
if err != nil {
return err
}
acc.addresses[addrID].keys = append(acc.addresses[addrID].keys, key{
keyID: uuid.NewString(),
key: armKey,
tok: encToken,
sig: sigToken,
})
updateID, err := b.newUpdate(&addressUpdated{addressID: addrID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) RemoveAddress(userID, addrID string) error {
return b.withAcc(userID, func(acc *account) error {
if _, ok := acc.addresses[addrID]; !ok {
return fmt.Errorf("address %s not found", addrID)
}
delete(acc.addresses, addrID)
updateID, err := b.newUpdate(&addressDeleted{addressID: addrID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) RemoveAddressKey(userID, addrID, keyID string) error {
return b.withAcc(userID, func(acc *account) error {
idx := xslices.IndexFunc(acc.addresses[addrID].keys, func(key key) bool {
return key.keyID == keyID
})
if idx < 0 {
return fmt.Errorf("key %s not found", keyID)
}
acc.addresses[addrID].keys = append(acc.addresses[addrID].keys[:idx], acc.addresses[addrID].keys[idx+1:]...)
updateID, err := b.newUpdate(&addressUpdated{addressID: addrID})
if err != nil {
return err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return nil
})
}
func (b *Backend) CreateMessage(
userID, addrID string,
subject string,
sender *mail.Address,
toList, ccList, bccList []*mail.Address,
armBody string,
mimeType rfc822.MIMEType,
flags proton.MessageFlag,
unread, starred bool,
) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
return withMessages(b, func(messages map[string]*message) (string, error) {
msg := newMessage(addrID, subject, sender, toList, ccList, bccList, armBody, mimeType, "")
msg.flags |= flags
msg.unread = unread
msg.starred = starred
messages[msg.messageID] = msg
updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID})
if err != nil {
return "", err
}
acc.messageIDs = append(acc.messageIDs, msg.messageID)
acc.updateIDs = append(acc.updateIDs, updateID)
return msg.messageID, nil
})
})
}
func (b *Backend) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
return withMessages(b, func(messages map[string]*message) (string, error) {
if _, ok := messages[draftID]; !ok {
return "", fmt.Errorf("message %q not found", draftID)
}
messages[draftID].applyChanges(changes)
updateID, err := b.newUpdate(&messageUpdated{messageID: draftID})
if err != nil {
return "", err
}
acc.updateIDs = append(acc.updateIDs, updateID)
return draftID, nil
})
})
}
func (b *Backend) Encrypt(userID, addrID, decBody string) (string, error) {
return withAcc(b, userID, func(acc *account) (string, error) {
pubKey, err := acc.addresses[addrID].keys[0].getPubKey()
if err != nil {
return "", err
}
kr, err := crypto.NewKeyRing(pubKey)
if err != nil {
return "", err
}
enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(decBody), nil)
if err != nil {
return "", err
}
return enc.GetArmored()
})
}
func (b *Backend) withAcc(userID string, fn func(acc *account) error) error {
b.accLock.RLock()
defer b.accLock.RUnlock()
acc, ok := b.accounts[userID]
if !ok {
return fmt.Errorf("account %s not found", userID)
}
return fn(acc)
}
func (b *Backend) withAccEmail(email string, fn func(acc *account) error) error {
b.accLock.RLock()
defer b.accLock.RUnlock()
for _, acc := range b.accounts {
for _, addr := range acc.addresses {
if addr.email == email {
return fn(acc)
}
}
}
return fmt.Errorf("account %s not found", email)
}
func withAcc[T any](b *Backend, userID string, fn func(acc *account) (T, error)) (T, error) {
b.accLock.RLock()
defer b.accLock.RUnlock()
for _, acc := range b.accounts {
if acc.userID == userID {
return fn(acc)
}
}
return *new(T), fmt.Errorf("account not found")
}
func withAccName[T any](b *Backend, username string, fn func(acc *account) (T, error)) (T, error) {
b.accLock.RLock()
defer b.accLock.RUnlock()
for _, acc := range b.accounts {
if acc.username == username {
return fn(acc)
}
}
return *new(T), fmt.Errorf("account not found")
}
func withAccEmail[T any](b *Backend, email string, fn func(acc *account) (T, error)) (T, error) {
b.accLock.RLock()
defer b.accLock.RUnlock()
for _, acc := range b.accounts {
if _, ok := acc.getAddr(email); ok {
return fn(acc)
}
}
return *new(T), fmt.Errorf("account not found")
}
func withAccAuth[T any](b *Backend, authUID, authAcc string, fn func(acc *account) (T, error)) (T, error) {
b.accLock.RLock()
defer b.accLock.RUnlock()
for _, acc := range b.accounts {
acc.authLock.RLock()
defer acc.authLock.RUnlock()
auth, ok := acc.auth[authUID]
if !ok {
continue
}
if auth.acc == authAcc {
return fn(acc)
}
}
return *new(T), fmt.Errorf("account not found")
}
func (b *Backend) withMessages(fn func(map[string]*message) error) error {
b.msgLock.Lock()
defer b.msgLock.Unlock()
return fn(b.messages)
}
func withMessages[T any](b *Backend, fn func(map[string]*message) (T, error)) (T, error) {
b.msgLock.Lock()
defer b.msgLock.Unlock()
return fn(b.messages)
}
func withAtts[T any](b *Backend, fn func(map[string]*attachment) (T, error)) (T, error) {
b.attLock.Lock()
defer b.attLock.Unlock()
return fn(b.attachments)
}
func (b *Backend) withLabels(fn func(map[string]*label) error) error {
b.lblLock.Lock()
defer b.lblLock.Unlock()
return fn(b.labels)
}
func withLabels[T any](b *Backend, fn func(map[string]*label) (T, error)) (T, error) {
b.lblLock.Lock()
defer b.lblLock.Unlock()
return fn(b.labels)
}
func (b *Backend) newUpdate(event update) (ID, error) {
return withUpdates(b, func(updates map[ID]update) (ID, error) {
updateID := ID(len(updates))
updates[updateID] = event
return updateID, nil
})
}
func withUpdates[T any](b *Backend, fn func(map[ID]update) (T, error)) (T, error) {
b.updatesLock.Lock()
defer b.updatesLock.Unlock()
return fn(b.updates)
}

42
server/backend/crypto.go Normal file
View File

@@ -0,0 +1,42 @@
package backend
import (
"github.com/ProtonMail/go-srp"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/gopenpgp/v2/helper"
)
var GenerateKey = helper.GenerateKey
func hashPassword(password, salt []byte) ([]byte, error) {
passphrase, err := srp.MailboxPassword(password, salt)
if err != nil {
return nil, err
}
return passphrase[len(passphrase)-31:], nil
}
func encryptWithSignature(kr *crypto.KeyRing, b []byte) (string, string, error) {
enc, err := kr.Encrypt(crypto.NewPlainMessage(b), nil)
if err != nil {
return "", "", err
}
encArm, err := enc.GetArmored()
if err != nil {
return "", "", err
}
sig, err := kr.SignDetached(crypto.NewPlainMessage(b))
if err != nil {
return "", "", err
}
sigArm, err := sig.GetArmored()
if err != nil {
return "", "", err
}
return encArm, sigArm, nil
}

39
server/backend/label.go Normal file
View File

@@ -0,0 +1,39 @@
package backend
import (
"github.com/ProtonMail/go-proton-api"
"github.com/google/uuid"
)
type label struct {
labelID string
parentID string
name string
labelType proton.LabelType
messageIDs map[string]struct{}
}
func newLabel(labelName, parentID string, labelType proton.LabelType) *label {
return &label{
labelID: uuid.NewString(),
parentID: parentID,
name: labelName,
labelType: labelType,
messageIDs: make(map[string]struct{}),
}
}
func (label *label) toLabel(labels map[string]*label) proton.Label {
var path []string
for labelID := label.labelID; labelID != ""; labelID = labels[labelID].parentID {
path = append([]string{labels[labelID].name}, path...)
}
return proton.Label{
ID: label.labelID,
Name: label.name,
Path: path,
Type: label.labelType,
}
}

300
server/backend/message.go Normal file
View File

@@ -0,0 +1,300 @@
package backend
import (
"net/mail"
"strings"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"golang.org/x/exp/slices"
)
type message struct {
messageID string
externalID string
addrID string
labelIDs []string
sysLabel *string
attIDs []string
subject string
sender *mail.Address
toList []*mail.Address
ccList []*mail.Address
bccList []*mail.Address
armBody string
mimeType rfc822.MIMEType
flags proton.MessageFlag
unread bool
starred bool
}
func newMessage(
addrID string,
subject string,
sender *mail.Address,
toList, ccList, bccList []*mail.Address,
armBody string,
mimeType rfc822.MIMEType,
externalID string,
) *message {
return &message{
messageID: uuid.NewString(),
externalID: externalID,
addrID: addrID,
sysLabel: pointer(""),
subject: subject,
sender: sender,
toList: toList,
ccList: ccList,
bccList: bccList,
armBody: armBody,
mimeType: mimeType,
}
}
func (msg *message) toMessage(att map[string]*attachment) proton.Message {
return proton.Message{
MessageMetadata: msg.toMetadata(),
Header: msg.getHeader(),
ParsedHeaders: msg.getParsedHeaders(),
Body: msg.armBody,
MIMEType: msg.mimeType,
Attachments: xslices.Map(msg.attIDs, func(attID string) proton.Attachment {
return att[attID].toAttachment()
}),
}
}
func (msg *message) toMetadata() proton.MessageMetadata {
labelIDs := []string{proton.AllMailLabel}
if msg.flags.Has(proton.MessageFlagSent) {
labelIDs = append(labelIDs, proton.AllSentLabel)
}
if !msg.flags.HasAny(proton.MessageFlagSent, proton.MessageFlagReceived) {
labelIDs = append(labelIDs, proton.AllDraftsLabel)
}
if msg.starred {
labelIDs = append(labelIDs, proton.StarredLabel)
}
if msg.sysLabel != nil {
if *msg.sysLabel != "" {
labelIDs = append(labelIDs, *msg.sysLabel)
}
} else {
switch {
case msg.flags.Has(proton.MessageFlagReceived):
labelIDs = append(labelIDs, proton.InboxLabel)
case msg.flags.Has(proton.MessageFlagSent):
labelIDs = append(labelIDs, proton.SentLabel)
default:
labelIDs = append(labelIDs, proton.DraftsLabel)
}
}
return proton.MessageMetadata{
ID: msg.messageID,
ExternalID: msg.externalID,
AddressID: msg.addrID,
LabelIDs: append(msg.labelIDs, labelIDs...),
Subject: msg.subject,
Sender: msg.sender,
ToList: msg.toList,
CCList: msg.ccList,
BCCList: msg.bccList,
Flags: msg.flags,
Unread: proton.Bool(msg.unread),
}
}
func (msg *message) getHeader() string {
builder := new(strings.Builder)
builder.WriteString("Subject: " + msg.subject + "\r\n")
if msg.sender != nil {
builder.WriteString("From: " + msg.sender.String() + "\r\n")
}
if len(msg.toList) > 0 {
builder.WriteString("To: " + toAddressList(msg.toList) + "\r\n")
}
if len(msg.ccList) > 0 {
builder.WriteString("Cc: " + toAddressList(msg.ccList) + "\r\n")
}
if len(msg.bccList) > 0 {
builder.WriteString("Bcc: " + toAddressList(msg.bccList) + "\r\n")
}
if msg.mimeType != "" {
builder.WriteString("Content-Type: " + string(msg.mimeType) + "\r\n")
}
return builder.String()
}
func (msg *message) getParsedHeaders() proton.Headers {
header, err := rfc822.NewHeader([]byte(msg.getHeader()))
if err != nil {
panic(err)
}
parsed := make(proton.Headers)
header.Entries(func(key, value string) {
parsed[key] = append(parsed[key], value)
})
return parsed
}
// applyChanges will apply non-nil field from passed message.
//
// NOTE: This is not feature complete. It might panic on non-implemented
// changes.
func (msg *message) applyChanges(changes proton.DraftTemplate) {
if changes.Subject != "" {
msg.subject = changes.Subject
}
if changes.Sender != nil {
panic("sender change probably not allowed by API on existing draft")
}
if changes.ToList != nil {
msg.toList = append([]*mail.Address{}, changes.ToList...)
}
if changes.CCList != nil {
msg.ccList = append([]*mail.Address{}, changes.CCList...)
}
if changes.BCCList != nil {
msg.bccList = append([]*mail.Address{}, changes.BCCList...)
}
if changes.Body != "" {
msg.armBody = changes.Body
}
if changes.MIMEType != "" {
msg.mimeType = changes.MIMEType
}
if changes.ExternalID != "" {
msg.externalID = changes.ExternalID
}
}
func (msg *message) addLabel(labelID string, labels map[string]*label) {
switch labelID {
case proton.InboxLabel, proton.SentLabel, proton.DraftsLabel:
msg.addFlagLabel(labelID, labels)
case proton.TrashLabel, proton.SpamLabel, proton.ArchiveLabel:
msg.addSystemLabel(labelID, labels)
case proton.StarredLabel:
msg.starred = true
default:
if label, ok := labels[labelID]; ok {
msg.addUserLabel(label, labels)
}
}
}
func (msg *message) addFlagLabel(labelID string, labels map[string]*label) {
msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool {
return labels[otherLabelID].labelType == proton.LabelTypeLabel
})
msg.sysLabel = nil
}
func (msg *message) addSystemLabel(labelID string, labels map[string]*label) {
msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool {
return labels[otherLabelID].labelType == proton.LabelTypeLabel
})
msg.sysLabel = &labelID
}
func (msg *message) addUserLabel(label *label, labels map[string]*label) {
if label.labelType != proton.LabelTypeLabel {
msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool {
return labels[otherLabelID].labelType == proton.LabelTypeLabel
})
msg.sysLabel = pointer("")
}
if !slices.Contains(msg.labelIDs, label.labelID) {
msg.labelIDs = append(msg.labelIDs, label.labelID)
}
}
func (msg *message) remLabel(labelID string, labels map[string]*label) {
switch labelID {
case proton.InboxLabel, proton.SentLabel, proton.DraftsLabel:
msg.remFlagLabel(labelID, labels)
case proton.TrashLabel, proton.SpamLabel, proton.ArchiveLabel:
msg.remSystemLabel(labelID, labels)
case proton.StarredLabel:
msg.starred = false
default:
if label, ok := labels[labelID]; ok {
msg.remUserLabel(label, labels)
}
}
}
func (msg *message) remFlagLabel(labelID string, labels map[string]*label) {
msg.sysLabel = pointer("")
}
func (msg *message) remSystemLabel(labelID string, labels map[string]*label) {
if msg.sysLabel != nil && *msg.sysLabel == labelID {
msg.sysLabel = pointer("")
}
}
func (msg *message) remUserLabel(label *label, labels map[string]*label) {
msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool {
return otherLabelID != label.labelID
})
}
func toAddressList(addrs []*mail.Address) string {
res := make([]string, len(addrs))
for i, addr := range addrs {
res[i] = addr.String()
}
return strings.Join(res, ", ")
}
func pointer[T any](v T) *T {
return &v
}

View File

@@ -0,0 +1 @@
+88jb48lF5TyDBveyHZ7QhSvtc4V3pN8/eQW6kk6ok2egy4lr5Wz9h8iZP3erN9lReSx1Lk+WsLu1b3soDhXX/twTCUhxYwjS8r983aEshZJJq7p5tNroQ5pzrZMbK8Oszjajgdg2YzcMcaJqb9+Doi7egj/esUQ+Q7BWdxeK77Wafj9v7PiW6Ozx6ulppu1mZ+YGnXSXJsl1Cl4nPm7PNkgj4BQT3HLrxakh7Xc3agmepRKO/1jLaOBU/oO17URbA5rwh/ZlAOqEAKH5vJ+hA2acM3Bwsa/K8I/jWicxOoaLZ4RZFpLYvOxGbb4DggR2Ri/C6tNyeEQQKAtxpeV5g==

24
server/backend/modulus.go Normal file
View File

@@ -0,0 +1,24 @@
package backend
import (
_ "embed"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
var modulus string
func init() {
arm, err := crypto.NewClearTextMessage(asc, sig).GetArmored()
if err != nil {
panic(err)
}
modulus = arm
}
//go:embed modulus.asc
var asc []byte
//go:embed modulus.sig
var sig []byte

BIN
server/backend/modulus.sig Normal file
View File

Binary file not shown.

112
server/backend/types.go Normal file
View File

@@ -0,0 +1,112 @@
package backend
import (
"encoding/base64"
"math/big"
"time"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/google/uuid"
)
type ID uint64
func (v ID) String() string {
return base64.StdEncoding.EncodeToString(v.Bytes())
}
func (v ID) Bytes() []byte {
if v == 0 {
return []byte{0}
}
return new(big.Int).SetUint64(uint64(v)).Bytes()
}
func (v *ID) FromString(s string) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return err
}
*v = ID(new(big.Int).SetBytes(b).Uint64())
return nil
}
type auth struct {
acc string
ref string
expiration time.Time
creation time.Time
}
func newAuth(authLife time.Duration) auth {
return auth{
acc: uuid.NewString(),
ref: uuid.NewString(),
expiration: time.Now().Add(authLife),
creation: time.Now(),
}
}
func (auth *auth) toAuth(userID, authUID string, proof []byte) proton.Auth {
return proton.Auth{
UserID: userID,
UID: authUID,
AccessToken: auth.acc,
RefreshToken: auth.ref,
ServerProof: base64.StdEncoding.EncodeToString(proof),
ExpiresIn: int(time.Until(auth.expiration).Seconds()),
TwoFA: proton.TwoFAInfo{Enabled: proton.TwoFADisabled},
PasswordMode: proton.OnePasswordMode,
}
}
func (auth *auth) toAuthSession(authUID string) proton.AuthSession {
return proton.AuthSession{
UID: authUID,
CreateTime: auth.creation.Unix(),
Revocable: true,
}
}
type key struct {
keyID string
key string
tok string
sig string
}
func (key key) unlock(passphrase []byte) (*crypto.KeyRing, error) {
lockedKey, err := crypto.NewKeyFromArmored(key.key)
if err != nil {
return nil, err
}
unlockedKey, err := lockedKey.Unlock(passphrase)
if err != nil {
return nil, err
}
return crypto.NewKeyRing(unlockedKey)
}
func (key key) getPubKey() (*crypto.Key, error) {
privKey, err := crypto.NewKeyFromArmored(key.key)
if err != nil {
return nil, err
}
pubKeyBin, err := privKey.GetPublicKey()
if err != nil {
return nil, err
}
return crypto.NewKey(pubKeyBin)
}

View File

@@ -0,0 +1,23 @@
package backend
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestID(t *testing.T) {
var v ID
// We can set the ID from a string.
require.NoError(t, v.FromString("AQIDBA=="))
// We can get the ID as a string.
require.Equal(t, "AQIDBA==", v.String())
// We can get the ID as bytes.
require.Equal(t, []byte{1, 2, 3, 4}, v.Bytes())
// The ID is correct.
require.Equal(t, ID(0x01020304), v)
}

175
server/backend/updates.go Normal file
View File

@@ -0,0 +1,175 @@
package backend
import (
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xslices"
)
func merge(updates []update) []update {
if len(updates) < 2 {
return updates
}
if merged := merge(updates[1:]); xslices.IndexFunc(merged, func(other update) bool {
return other.replaces(updates[0])
}) < 0 {
return append([]update{updates[0]}, merged...)
} else {
return merged
}
}
type update interface {
replaces(other update) bool
_isUpdate()
}
type baseUpdate struct{}
func (baseUpdate) replaces(update) bool {
return false
}
func (baseUpdate) _isUpdate() {}
type userRefreshed struct {
baseUpdate
refresh proton.RefreshFlag
}
type messageCreated struct {
baseUpdate
messageID string
}
type messageUpdated struct {
baseUpdate
messageID string
}
func (update *messageUpdated) replaces(other update) bool {
switch other := other.(type) {
case *messageUpdated:
return update.messageID == other.messageID
default:
return false
}
}
type messageDeleted struct {
baseUpdate
messageID string
}
func (update *messageDeleted) replaces(other update) bool {
switch other := other.(type) {
case *messageCreated:
return update.messageID == other.messageID
case *messageUpdated:
return update.messageID == other.messageID
case *messageDeleted:
if update.messageID != other.messageID {
return false
}
panic("message deleted twice")
default:
return false
}
}
type labelCreated struct {
baseUpdate
labelID string
}
type labelUpdated struct {
baseUpdate
labelID string
}
func (update *labelUpdated) replaces(other update) bool {
switch other := other.(type) {
case *labelUpdated:
return update.labelID == other.labelID
default:
return false
}
}
type labelDeleted struct {
baseUpdate
labelID string
}
func (update *labelDeleted) replaces(other update) bool {
switch other := other.(type) {
case *labelCreated:
return update.labelID == other.labelID
case *labelUpdated:
return update.labelID == other.labelID
case *labelDeleted:
if update.labelID != other.labelID {
return false
}
panic("label deleted twice")
default:
return false
}
}
type addressCreated struct {
baseUpdate
addressID string
}
type addressUpdated struct {
baseUpdate
addressID string
}
func (update *addressUpdated) replaces(other update) bool {
switch other := other.(type) {
case *addressUpdated:
return update.addressID == other.addressID
default:
return false
}
}
type addressDeleted struct {
baseUpdate
addressID string
}
func (update *addressDeleted) replaces(other update) bool {
switch other := other.(type) {
case *addressCreated:
return update.addressID == other.addressID
case *addressUpdated:
return update.addressID == other.addressID
case *addressDeleted:
if update.addressID != other.addressID {
return false
}
panic("address deleted twice")
default:
return false
}
}

View File

@@ -0,0 +1,63 @@
package backend
import (
"reflect"
"testing"
)
func Test_mergeUpdates(t *testing.T) {
tests := []struct {
name string
have []update
want []update
}{
{
name: "single",
have: []update{&labelCreated{labelID: "1"}},
want: []update{&labelCreated{labelID: "1"}},
},
{
name: "multiple",
have: []update{
&labelCreated{labelID: "1"},
&labelCreated{labelID: "2"},
},
want: []update{
&labelCreated{labelID: "1"},
&labelCreated{labelID: "2"},
},
},
{
name: "replace with updated",
have: []update{
&labelCreated{labelID: "1"},
&labelUpdated{labelID: "1"},
&labelUpdated{labelID: "1"},
},
want: []update{
&labelCreated{labelID: "1"},
&labelUpdated{labelID: "1"},
},
},
{
name: "replace with delete",
have: []update{
&labelCreated{labelID: "1"},
&labelUpdated{labelID: "1"},
&labelUpdated{labelID: "1"},
&labelDeleted{labelID: "1"},
},
want: []update{
&labelDeleted{labelID: "1"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := merge(tt.have); !reflect.DeepEqual(got, tt.want) {
t.Errorf("mergeUpdates() = %v, want %v", got, tt.want)
}
})
}
}

52
server/cache.go Normal file
View File

@@ -0,0 +1,52 @@
package server
import (
"sync"
"github.com/ProtonMail/go-proton-api"
)
func NewAuthCache() AuthCacher {
return &authCache{
info: make(map[string]proton.AuthInfo),
auth: make(map[string]proton.Auth),
}
}
type authCache struct {
info map[string]proton.AuthInfo
auth map[string]proton.Auth
lock sync.RWMutex
}
func (c *authCache) GetAuthInfo(username string) (proton.AuthInfo, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
info, ok := c.info[username]
return info, ok
}
func (c *authCache) SetAuthInfo(username string, info proton.AuthInfo) {
c.lock.Lock()
defer c.lock.Unlock()
c.info[username] = info
}
func (c *authCache) GetAuth(username string) (proton.Auth, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
auth, ok := c.auth[username]
return auth, ok
}
func (c *authCache) SetAuth(username string, auth proton.Auth) {
c.lock.Lock()
defer c.lock.Unlock()
c.auth[username] = auth
}

50
server/call.go Normal file
View File

@@ -0,0 +1,50 @@
package server
import (
"net/http"
"net/url"
)
type Call struct {
URL *url.URL
Method string
Status int
RequestHeader http.Header
RequestBody []byte
ResponseHeader http.Header
ResponseBody []byte
}
type callWatcher struct {
paths map[string]struct{}
callFn func(Call)
}
func newCallWatcher(fn func(Call), paths ...string) callWatcher {
pathMap := make(map[string]struct{}, len(paths))
for _, path := range paths {
pathMap[path] = struct{}{}
}
return callWatcher{
paths: pathMap,
callFn: fn,
}
}
func (watcher *callWatcher) isWatching(path string) bool {
if len(watcher.paths) == 0 {
return true
}
_, ok := watcher.paths[path]
return ok
}
func (watcher *callWatcher) publish(call Call) {
watcher.callFn(call)
}

291
server/cmd/client/client.go Normal file
View File

@@ -0,0 +1,291 @@
package main
import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"os"
"github.com/ProtonMail/go-proton-api/server/proto"
"github.com/urfave/cli/v2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func main() {
app := cli.NewApp()
app.Flags = []cli.Flag{
&cli.StringFlag{
Name: "host",
Usage: "host to connect to",
Value: "localhost",
},
&cli.IntFlag{
Name: "port",
Usage: "port to connect to",
Value: 8080,
},
}
app.Commands = []*cli.Command{
{
Name: "info",
Action: getInfoAction,
},
{
Name: "auth",
Subcommands: []*cli.Command{
{
Name: "revoke",
Action: revokeUserAction,
Flags: []cli.Flag{
&cli.StringFlag{
Name: "userID",
Usage: "user ID to revoke",
Required: true,
},
},
},
},
},
{
Name: "user",
Subcommands: []*cli.Command{
{
Name: "create",
Action: createUserAction,
Flags: []cli.Flag{
&cli.StringFlag{
Name: "username",
Usage: "username of the account",
Required: true,
},
&cli.StringFlag{
Name: "email",
Usage: "email of the account",
Required: true,
},
&cli.StringFlag{
Name: "password",
Usage: "password of the account",
Required: true,
},
},
},
},
},
{
Name: "address",
Subcommands: []*cli.Command{
{
Name: "create",
Action: createAddressAction,
Flags: []cli.Flag{
&cli.StringFlag{
Name: "userID",
Usage: "ID of the user to create the address for",
Required: true,
},
&cli.StringFlag{
Name: "email",
Usage: "email of the account",
Required: true,
},
&cli.StringFlag{
Name: "password",
Usage: "password of the account",
Required: true,
},
},
},
{
Name: "remove",
Action: removeAddressAction,
Flags: []cli.Flag{
&cli.StringFlag{
Name: "userID",
Usage: "ID of the user to remove the address from",
Required: true,
},
&cli.StringFlag{
Name: "addressID",
Usage: "ID of the address to remove",
Required: true,
},
},
},
},
},
{
Name: "label",
Subcommands: []*cli.Command{
{
Name: "create",
Action: createLabelAction,
Flags: []cli.Flag{
&cli.StringFlag{
Name: "userID",
Usage: "ID of the user to create the label for",
Required: true,
},
&cli.StringFlag{
Name: "name",
Usage: "name of the label",
Required: true,
},
&cli.StringFlag{
Name: "parentID",
Usage: "the ID of the parent label",
},
&cli.BoolFlag{
Name: "exclusive",
Usage: "Create an exclusive label (i.e. a folder)",
},
},
},
},
},
}
if err := app.Run(os.Args); err != nil {
log.Fatal(err)
}
}
func getInfoAction(c *cli.Context) error {
client, err := newServerClient(c)
if err != nil {
return err
}
res, err := client.GetInfo(c.Context, &proto.GetInfoRequest{})
if err != nil {
return err
}
return pretty(c.App.Writer, res)
}
func createUserAction(c *cli.Context) error {
client, err := newServerClient(c)
if err != nil {
return err
}
res, err := client.CreateUser(c.Context, &proto.CreateUserRequest{
Username: c.String("username"),
Email: c.String("email"),
Password: []byte(c.String("password")),
})
if err != nil {
return err
}
return pretty(c.App.Writer, res)
}
func revokeUserAction(c *cli.Context) error {
client, err := newServerClient(c)
if err != nil {
return err
}
res, err := client.RevokeUser(c.Context, &proto.RevokeUserRequest{
UserID: c.String("userID"),
})
if err != nil {
return err
}
return pretty(c.App.Writer, res)
}
func createAddressAction(c *cli.Context) error {
client, err := newServerClient(c)
if err != nil {
return err
}
res, err := client.CreateAddress(c.Context, &proto.CreateAddressRequest{
UserID: c.String("userID"),
Email: c.String("email"),
Password: []byte(c.String("password")),
})
if err != nil {
return err
}
return pretty(c.App.Writer, res)
}
func removeAddressAction(c *cli.Context) error {
client, err := newServerClient(c)
if err != nil {
return err
}
res, err := client.RemoveAddress(c.Context, &proto.RemoveAddressRequest{
UserID: c.String("userID"),
AddrID: c.String("addressID"),
})
if err != nil {
return err
}
return pretty(c.App.Writer, res)
}
func createLabelAction(c *cli.Context) error {
client, err := newServerClient(c)
if err != nil {
return err
}
var labelType proto.LabelType
if c.Bool("exclusive") {
labelType = proto.LabelType_FOLDER
} else {
labelType = proto.LabelType_LABEL
}
res, err := client.CreateLabel(c.Context, &proto.CreateLabelRequest{
UserID: c.String("userID"),
Name: c.String("name"),
Type: labelType,
})
if err != nil {
return err
}
return pretty(c.App.Writer, res)
}
func newServerClient(c *cli.Context) (proto.ServerClient, error) {
cc, err := grpc.DialContext(
c.Context,
net.JoinHostPort(c.String("host"), fmt.Sprint(c.Int("port"))),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}
return proto.NewServerClient(cc), nil
}
func pretty[T any](w io.Writer, v T) error {
enc, err := json.MarshalIndent(v, "", " ")
if err != nil {
return err
}
if _, err := w.Write(enc); err != nil {
return err
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More