commit 2323ea736020d732c3eebf89abc6d669d11886b9 Author: James Houlahan Date: Wed Nov 23 11:17:54 2022 +0100 feat: Initial open source commit diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml new file mode 100644 index 0000000..10dc835 --- /dev/null +++ b/.github/workflows/check.yml @@ -0,0 +1,28 @@ +name: Lint and Test + +on: push + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Get sources + uses: actions/checkout@v3 + + - name: Set up Go 1.18 + uses: actions/setup-go@v3 + with: + go-version: '1.18' + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: v1.50.0 + args: --timeout=180s + skip-cache: true + + - name: Run tests + run: go test -v ./... + + - name: Run tests with race check + run: go test -v -race ./... diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..bd6b7eb --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,10 @@ +# Contribution Policy + +By making a contribution to this project: + +1. I assign any and all copyright related to the contribution to Proton AG; +2. I certify that the contribution was created in whole by me; +3. I understand and agree that this project and the contribution are public + and that a record of the contribution (including all personal information I + submit with it) is maintained indefinitely and may be redistributed with + this project or the open source license(s) involved. \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c6d23e4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2020 James Houlahan +Copyright (c) 2022 Proton AG + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..382b02a --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +# Go Proton API + +CI Status +GoDoc +Go Report Card +License + +This repository holds Go Proton API, a Go library implementing a client and development server for (a subset of) the Proton REST API. + +The license can be found in the [LICENSE](./LICENSE) file. + +For the contribution policy, see [CONTRIBUTING](./CONTRIBUTING.md). + +## Environment variables + +Most of the integration tests run locally. The ones that interact with Proton servers require the following environment variables set: + +- ```GO_PROTON_API_TEST_USERNAME``` +- ```GO_PROTON_API_TEST_PASSWORD``` + +## Contribution + +The library is maintained by Proton AG, and is not actively looking for contributors. diff --git a/address.go b/address.go new file mode 100644 index 0000000..ec4d330 --- /dev/null +++ b/address.go @@ -0,0 +1,46 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" + "golang.org/x/exp/slices" +) + +func (c *Client) GetAddresses(ctx context.Context) ([]Address, error) { + var res struct { + Addresses []Address + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/core/v4/addresses") + }); err != nil { + return nil, err + } + + slices.SortFunc(res.Addresses, func(a, b Address) bool { + return a.Order < b.Order + }) + + return res.Addresses, nil +} + +func (c *Client) GetAddress(ctx context.Context, addressID string) (Address, error) { + var res struct { + Address Address + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/core/v4/addresses/" + addressID) + }); err != nil { + return Address{}, err + } + + return res.Address, nil +} + +func (c *Client) OrderAddresses(ctx context.Context, req OrderAddressesReq) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).Put("/core/v4/addresses/order") + }) +} diff --git a/address_types.go b/address_types.go new file mode 100644 index 0000000..6affeb6 --- /dev/null +++ b/address_types.go @@ -0,0 +1,27 @@ +package proton + +type Address struct { + ID string + Email string + + Send Bool + Receive Bool + Status AddressStatus + + Order int + DisplayName string + + Keys Keys +} + +type OrderAddressesReq struct { + AddressIDs []string +} + +type AddressStatus int + +const ( + AddressStatusDisabled AddressStatus = iota + AddressStatusEnabled + AddressStatusDeleting +) diff --git a/atomic.go b/atomic.go new file mode 100644 index 0000000..5df7799 --- /dev/null +++ b/atomic.go @@ -0,0 +1,33 @@ +package proton + +import "sync/atomic" + +type atomicUint64 struct { + v uint64 +} + +func (x *atomicUint64) Load() uint64 { return atomic.LoadUint64(&x.v) } + +func (x *atomicUint64) Store(val uint64) { atomic.StoreUint64(&x.v, val) } + +func (x *atomicUint64) Swap(new uint64) (old uint64) { return atomic.SwapUint64(&x.v, new) } + +func (x *atomicUint64) Add(delta uint64) (new uint64) { return atomic.AddUint64(&x.v, delta) } + +type atomicBool struct { + v uint32 +} + +func (x *atomicBool) Load() bool { return atomic.LoadUint32(&x.v) != 0 } + +func (x *atomicBool) Store(val bool) { atomic.StoreUint32(&x.v, b32(val)) } + +func (x *atomicBool) Swap(new bool) (old bool) { return atomic.SwapUint32(&x.v, b32(new)) != 0 } + +func b32(b bool) uint32 { + if b { + return 1 + } + + return 0 +} diff --git a/attachment.go b/attachment.go new file mode 100644 index 0000000..3b0db61 --- /dev/null +++ b/attachment.go @@ -0,0 +1,79 @@ +package proton + +import ( + "bytes" + "context" + "fmt" + "io" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetAttachment(ctx context.Context, attachmentID string) ([]byte, error) { + return c.attPool().ProcessOne(ctx, attachmentID) +} + +func (c *Client) UploadAttachment(ctx context.Context, addrKR *crypto.KeyRing, req CreateAttachmentReq) (Attachment, error) { + var res struct { + Attachment Attachment + } + + sig, err := addrKR.SignDetached(crypto.NewPlainMessage(req.Body)) + if err != nil { + return Attachment{}, fmt.Errorf("failed to sign attachment: %w", err) + } + + enc, err := addrKR.EncryptAttachment(crypto.NewPlainMessage(req.Body), req.Filename) + if err != nil { + return Attachment{}, fmt.Errorf("failed to encrypt attachment: %w", err) + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res). + SetMultipartFormData(map[string]string{ + "MessageID": req.MessageID, + "Filename": req.Filename, + "MIMEType": string(req.MIMEType), + "Disposition": string(req.Disposition), + "ContentID": req.ContentID, + }). + SetMultipartFields( + &resty.MultipartField{ + Param: "KeyPackets", + FileName: "blob", + ContentType: "application/octet-stream", + Reader: bytes.NewReader(enc.KeyPacket), + }, + &resty.MultipartField{ + Param: "DataPacket", + FileName: "blob", + ContentType: "application/octet-stream", + Reader: bytes.NewReader(enc.DataPacket), + }, + &resty.MultipartField{ + Param: "Signature", + FileName: "blob", + ContentType: "application/octet-stream", + Reader: bytes.NewReader(sig.GetBinary()), + }, + ). + Post("/mail/v4/attachments") + }); err != nil { + return Attachment{}, err + } + + return res.Attachment, nil +} + +func (c *Client) getAttachment(ctx context.Context, attachmentID string) ([]byte, error) { + res, err := c.doRes(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetDoNotParseResponse(true).Get("/mail/v4/attachments/" + attachmentID) + }) + if err != nil { + return nil, err + } + defer res.RawBody().Close() + + return io.ReadAll(res.RawBody()) +} diff --git a/attachment_types.go b/attachment_types.go new file mode 100644 index 0000000..e0eed82 --- /dev/null +++ b/attachment_types.go @@ -0,0 +1,36 @@ +package proton + +import ( + "github.com/ProtonMail/gluon/rfc822" +) + +type Attachment struct { + ID string + + Name string + Size int64 + MIMEType rfc822.MIMEType + Disposition Disposition + Headers Headers + + KeyPackets string + Signature string +} + +type Disposition string + +const ( + InlineDisposition Disposition = "inline" + AttachmentDisposition Disposition = "attachment" +) + +type CreateAttachmentReq struct { + MessageID string + + Filename string + MIMEType rfc822.MIMEType + Disposition Disposition + ContentID string + + Body []byte +} diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..de14027 --- /dev/null +++ b/auth.go @@ -0,0 +1,45 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) Auth2FA(ctx context.Context, req Auth2FAReq) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).Post("/core/v4/auth/2fa") + }) +} + +func (c *Client) AuthDelete(ctx context.Context) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.Delete("/core/v4/auth") + }) +} + +func (c *Client) AuthSessions(ctx context.Context) ([]AuthSession, error) { + var res struct { + Sessions []AuthSession + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/auth/v4/sessions") + }); err != nil { + return nil, err + } + + return res.Sessions, nil +} + +func (c *Client) AuthRevoke(ctx context.Context, authUID string) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.Delete("/auth/v4/sessions/" + authUID) + }) +} + +func (c *Client) AuthRevokeAll(ctx context.Context) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.Delete("/auth/v4/sessions") + }) +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..bb38662 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,103 @@ +package proton_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" +) + +func TestAutomaticAuthRefresh(t *testing.T) { + wantAuth := proton.Auth{ + UID: "testUID", + AccessToken: "testAcc", + RefreshToken: "testRef", + } + + mux := http.NewServeMux() + + mux.HandleFunc("/core/v4/auth/refresh", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + if err := json.NewEncoder(w).Encode(wantAuth); err != nil { + panic(err) + } + }) + + mux.HandleFunc("/core/v4/users", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + ts := httptest.NewServer(mux) + defer ts.Close() + + var gotAuth proton.Auth + + // Create a new client. + c := proton.New(proton.WithHostURL(ts.URL)).NewClient("uid", "acc", "ref", time.Now().Add(-time.Second)) + defer c.Close() + + // Register an auth handler. + c.AddAuthHandler(func(auth proton.Auth) { gotAuth = auth }) + + // Make a request with an access token that already expired one second ago. + if _, err := c.GetUser(context.Background()); err != nil { + t.Fatal("got unexpected error", err) + } + + // The auth callback should have been called. + if !cmp.Equal(gotAuth, wantAuth) { + t.Fatal("got unexpected auth", gotAuth) + } +} + +func TestAuth(t *testing.T) { + s := server.New() + defer s.Close() + + _, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) + require.NoError(t, err) + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + // Create one session. + c1, auth1, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + require.NoError(t, err) + + // Revoke all other sessions. + require.NoError(t, c1.AuthRevokeAll(context.Background())) + + // Create another session. + c2, _, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + require.NoError(t, err) + + // There should be two sessions. + sessions, err := c1.AuthSessions(context.Background()) + require.NoError(t, err) + require.Len(t, sessions, 2) + + // Revoke the first session. + require.NoError(t, c2.AuthRevoke(context.Background(), auth1.UID)) + + // The first session should no longer work. + require.Error(t, c1.AuthDelete(context.Background())) + + // There should be one session remaining. + remaining, err := c2.AuthSessions(context.Background()) + require.NoError(t, err) + require.Len(t, remaining, 1) + + // Delete the last session. + require.NoError(t, c2.AuthDelete(context.Background())) +} diff --git a/auth_types.go b/auth_types.go new file mode 100644 index 0000000..1e536a9 --- /dev/null +++ b/auth_types.go @@ -0,0 +1,95 @@ +package proton + +type AuthInfoReq struct { + Username string +} + +type AuthInfo struct { + Version int + Modulus string + ServerEphemeral string + Salt string + SRPSession string + TwoFA TwoFAInfo `json:"2FA"` +} + +type U2FReq struct { + KeyHandle string + ClientData string + SignatureData string +} + +type AuthReq struct { + Username string + ClientEphemeral string + ClientProof string + SRPSession string + U2F U2FReq +} + +type Auth struct { + UserID string + + UID string + AccessToken string + RefreshToken string + ServerProof string + ExpiresIn int + + Scope string + TwoFA TwoFAInfo `json:"2FA"` + PasswordMode PasswordMode +} + +type RegisteredKey struct { + Version string + KeyHandle string +} + +type U2FInfo struct { + Challenge string + RegisteredKeys []RegisteredKey +} + +type TwoFAInfo struct { + Enabled TwoFAStatus + U2F U2FInfo +} + +type TwoFAStatus int + +const ( + TwoFADisabled TwoFAStatus = iota + TOTPEnabled +) + +type PasswordMode int + +const ( + OnePasswordMode PasswordMode = iota + 1 + TwoPasswordMode +) + +type Auth2FAReq struct { + TwoFactorCode string +} + +type AuthRefreshReq struct { + UID string + RefreshToken string + ResponseType string + GrantType string + RedirectURI string + State string +} + +type AuthSession struct { + UID string + CreateTime int64 + + ClientID string + MemberID string + Revocable Bool + + LocalizedClientName string +} diff --git a/block.go b/block.go new file mode 100644 index 0000000..7cad5bc --- /dev/null +++ b/block.go @@ -0,0 +1,19 @@ +package proton + +import ( + "context" + "io" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetBlock(ctx context.Context, url string) (io.ReadCloser, error) { + res, err := c.doRes(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetDoNotParseResponse(true).Get(url) + }) + if err != nil { + return nil, err + } + + return res.RawBody(), nil +} diff --git a/boolean.go b/boolean.go new file mode 100644 index 0000000..0f15f6d --- /dev/null +++ b/boolean.go @@ -0,0 +1,38 @@ +package proton + +import "encoding/json" + +// Bool is a convenience type for boolean values; it converts from APIBool to Go's builtin bool type. +type Bool bool + +// APIBool is the boolean type used by the API (0 or 1). +type APIBool int + +const ( + APIFalse APIBool = iota + APITrue +) + +func (b *Bool) UnmarshalJSON(data []byte) error { + var v APIBool + + if err := json.Unmarshal(data, &v); err != nil { + return err + } + + *b = Bool(v == APITrue) + + return nil +} + +func (b Bool) MarshalJSON() ([]byte, error) { + var v APIBool + + if b { + v = APITrue + } else { + v = APIFalse + } + + return json.Marshal(v) +} diff --git a/calendar.go b/calendar.go new file mode 100644 index 0000000..bd86eb2 --- /dev/null +++ b/calendar.go @@ -0,0 +1,77 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetCalendars(ctx context.Context) ([]Calendar, error) { + var res struct { + Calendars []Calendar + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1") + }); err != nil { + return nil, err + } + + return res.Calendars, nil +} + +func (c *Client) GetCalendar(ctx context.Context, calendarID string) (Calendar, error) { + var res struct { + Calendar Calendar + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1/" + calendarID) + }); err != nil { + return Calendar{}, err + } + + return res.Calendar, nil +} + +func (c *Client) GetCalendarKeys(ctx context.Context, calendarID string) (CalendarKeys, error) { + var res struct { + Keys CalendarKeys + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/keys") + }); err != nil { + return nil, err + } + + return res.Keys, nil +} + +func (c *Client) GetCalendarMembers(ctx context.Context, calendarID string) ([]CalendarMember, error) { + var res struct { + Members []CalendarMember + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/members") + }); err != nil { + return nil, err + } + + return res.Members, nil +} + +func (c *Client) GetCalendarPassphrase(ctx context.Context, calendarID string) (CalendarPassphrase, error) { + var res struct { + Passphrase CalendarPassphrase + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/passphrase") + }); err != nil { + return CalendarPassphrase{}, err + } + + return res.Passphrase, nil +} diff --git a/calendar_event.go b/calendar_event.go new file mode 100644 index 0000000..f8f9bed --- /dev/null +++ b/calendar_event.go @@ -0,0 +1,66 @@ +package proton + +import ( + "context" + "net/url" + "strconv" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) CountCalendarEvents(ctx context.Context, calendarID string) (int, error) { + var res struct { + Total int + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/events") + }); err != nil { + return 0, err + } + + return res.Total, nil +} + +// TODO: For now, the query params are partially constant -- should they be configurable? +func (c *Client) GetCalendarEvents(ctx context.Context, calendarID string, page, pageSize int, filter url.Values) ([]CalendarEvent, error) { + var res struct { + Events []CalendarEvent + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetQueryParams(map[string]string{ + "Page": strconv.Itoa(page), + "PageSize": strconv.Itoa(pageSize), + }).SetQueryParamsFromValues(filter).SetResult(&res).Get("/calendar/v1/" + calendarID + "/events") + }); err != nil { + return nil, err + } + + return res.Events, nil +} + +func (c *Client) GetAllCalendarEvents(ctx context.Context, calendarID string, filter url.Values) ([]CalendarEvent, error) { + total, err := c.CountCalendarEvents(ctx, calendarID) + if err != nil { + return nil, err + } + + return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]CalendarEvent, error) { + return c.GetCalendarEvents(ctx, calendarID, page, pageSize, filter) + }) +} + +func (c *Client) GetCalendarEvent(ctx context.Context, calendarID, eventID string) (CalendarEvent, error) { + var res struct { + Event CalendarEvent + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/events/" + eventID) + }); err != nil { + return CalendarEvent{}, err + } + + return res.Event, nil +} diff --git a/calendar_event_types.go b/calendar_event_types.go new file mode 100644 index 0000000..3931837 --- /dev/null +++ b/calendar_event_types.go @@ -0,0 +1,110 @@ +package proton + +import ( + "encoding/base64" + + "github.com/ProtonMail/gopenpgp/v2/crypto" +) + +type CalendarEvent struct { + ID string + UID string + CalendarID string + SharedEventID string + + CreateTime int64 + LastEditTime int64 + StartTime int64 + StartTimezone string + EndTime int64 + EndTimezone string + FullDay Bool + + Author string + Permissions CalendarPermissions + Attendees []CalendarAttendee + + SharedKeyPacket string + CalendarKeyPacket string + + SharedEvents []CalendarEventPart + CalendarEvents []CalendarEventPart + AttendeesEvents []CalendarEventPart + PersonalEvents []CalendarEventPart +} + +// TODO: Only personal events have MemberID; should we have a different type for that? +type CalendarEventPart struct { + MemberID string + + Type CalendarEventType + Data string + Signature string + Author string +} + +func (part CalendarEventPart) Decode(calKR *crypto.KeyRing, addrKR *crypto.KeyRing, kp []byte) error { + if part.Type&CalendarEventTypeEncrypted != 0 { + var enc *crypto.PGPMessage + + if kp != nil { + raw, err := base64.StdEncoding.DecodeString(part.Data) + if err != nil { + return err + } + + enc = crypto.NewPGPSplitMessage(kp, raw).GetPGPMessage() + } else { + var err error + + if enc, err = crypto.NewPGPMessageFromArmored(part.Data); err != nil { + return err + } + } + + dec, err := calKR.Decrypt(enc, nil, crypto.GetUnixTime()) + if err != nil { + return err + } + + part.Data = dec.GetString() + } + + if part.Type&CalendarEventTypeSigned != 0 { + sig, err := crypto.NewPGPSignatureFromArmored(part.Signature) + if err != nil { + return err + } + + if err := addrKR.VerifyDetached(crypto.NewPlainMessageFromString(part.Data), sig, crypto.GetUnixTime()); err != nil { + return err + } + } + + return nil +} + +type CalendarEventType int + +const ( + CalendarEventTypeClear CalendarEventType = iota + CalendarEventTypeEncrypted + CalendarEventTypeSigned +) + +type CalendarAttendee struct { + ID string + Token string + Status CalendarAttendeeStatus + Permissions CalendarPermissions +} + +// TODO: What is this? +type CalendarAttendeeStatus int + +const ( + CalendarAttendeeStatusPending CalendarAttendeeStatus = iota + CalendarAttendeeStatusMaybe + CalendarAttendeeStatusNo + CalendarAttendeeStatusYes +) diff --git a/calendar_types.go b/calendar_types.go new file mode 100644 index 0000000..635ef6a --- /dev/null +++ b/calendar_types.go @@ -0,0 +1,140 @@ +package proton + +import ( + "errors" + + "github.com/ProtonMail/gopenpgp/v2/crypto" +) + +type Calendar struct { + ID string + Name string + Description string + Color string + Display Bool + + Type CalendarType + Flags CalendarFlag +} + +type CalendarFlag int64 + +const ( + CalendarFlagActive CalendarFlag = 1 << iota + CalendarFlagUpdatePassphrase + CalendarFlagResetNeeded + CalendarFlagIncompleteSetup + CalendarFlagLostAccess +) + +type CalendarType int + +const ( + CalendarTypeNormal CalendarType = iota + CalendarTypeSubscribed +) + +type CalendarKey struct { + ID string + CalendarID string + PassphraseID string + PrivateKey string + Flags CalendarKeyFlag +} + +func (key CalendarKey) Unlock(passphrase []byte) (*crypto.Key, error) { + lockedKey, err := crypto.NewKeyFromArmored(key.PrivateKey) + if err != nil { + return nil, err + } + + return lockedKey.Unlock(passphrase) +} + +type CalendarKeys []CalendarKey + +func (keys CalendarKeys) Unlock(passphrase []byte) (*crypto.KeyRing, error) { + kr, err := crypto.NewKeyRing(nil) + if err != nil { + return nil, err + } + + for _, key := range keys { + if k, err := key.Unlock(passphrase); err != nil { + continue + } else if err := kr.AddKey(k); err != nil { + return nil, err + } + } + + return kr, nil +} + +// TODO: What is this? +type CalendarKeyFlag int64 + +const ( + CalendarKeyFlagActive CalendarKeyFlag = 1 << iota + CalendarKeyFlagPrimary +) + +type CalendarMember struct { + ID string + Permissions CalendarPermissions + Email string + Color string + Display Bool + CalendarID string +} + +// TODO: What is this? +type CalendarPermissions int + +// TODO: Support invitations. +type CalendarPassphrase struct { + ID string + Flags CalendarPassphraseFlag + MemberPassphrases []MemberPassphrase +} + +func (passphrase CalendarPassphrase) Decrypt(memberID string, addrKR *crypto.KeyRing) ([]byte, error) { + for _, passphrase := range passphrase.MemberPassphrases { + if passphrase.MemberID == memberID { + return passphrase.decrypt(addrKR) + } + } + + return nil, errors.New("no such member passphrase") +} + +// TODO: What is this? +type CalendarPassphraseFlag int64 + +type MemberPassphrase struct { + MemberID string + Passphrase string + Signature string +} + +func (passphrase MemberPassphrase) decrypt(addrKR *crypto.KeyRing) ([]byte, error) { + msg, err := crypto.NewPGPMessageFromArmored(passphrase.Passphrase) + if err != nil { + return nil, err + } + + sig, err := crypto.NewPGPSignatureFromArmored(passphrase.Signature) + if err != nil { + return nil, err + } + + dec, err := addrKR.Decrypt(msg, nil, crypto.GetUnixTime()) + if err != nil { + return nil, err + } + + if err := addrKR.VerifyDetached(dec, sig, crypto.GetUnixTime()); err != nil { + return nil, err + } + + return dec.GetBinary(), nil +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..364be27 --- /dev/null +++ b/client.go @@ -0,0 +1,205 @@ +package proton + +import ( + "context" + "fmt" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/bradenaw/juniper/xsync" + "github.com/go-resty/resty/v2" +) + +// clientID is a unique identifier for a client. +var clientID uint64 + +// AuthHandler is given any new auths that are returned from the API due to an unexpected auth refresh. +type AuthHandler func(Auth) + +// Handler is a generic function that can be registered for a certain event (e.g. deauth, API code). +type Handler func() + +// Client is the proton client. +type Client struct { + m *Manager + + // clientID is this client's unique ID. + clientID uint64 + + // attPool is the (lazy-initialized) pool of goroutines that fetch attachments. + attPool func() *Pool[string, []byte] + + uid string + acc string + ref string + exp time.Time + authLock sync.RWMutex + + authHandlers []AuthHandler + deauthHandlers []Handler + hookLock sync.RWMutex + + deauthOnce sync.Once +} + +func newClient(m *Manager, uid string) *Client { + c := &Client{ + m: m, + uid: uid, + clientID: atomic.AddUint64(&clientID, 1), + } + + c.attPool = xsync.Lazy(func() *Pool[string, []byte] { + return NewPool(m.attPoolSize, c.getAttachment) + }) + + return c +} + +func (c *Client) AddAuthHandler(handler AuthHandler) { + c.hookLock.Lock() + defer c.hookLock.Unlock() + + c.authHandlers = append(c.authHandlers, handler) +} + +func (c *Client) AddDeauthHandler(handler Handler) { + c.hookLock.Lock() + defer c.hookLock.Unlock() + + c.deauthHandlers = append(c.deauthHandlers, handler) +} + +func (c *Client) AddPreRequestHook(hook resty.RequestMiddleware) { + c.hookLock.Lock() + defer c.hookLock.Unlock() + + c.m.rc.OnBeforeRequest(func(rc *resty.Client, r *resty.Request) error { + if clientID, ok := ClientIDFromContext(r.Context()); !ok || clientID != c.clientID { + return nil + } + + return hook(rc, r) + }) +} + +func (c *Client) AddPostRequestHook(hook resty.ResponseMiddleware) { + c.hookLock.Lock() + defer c.hookLock.Unlock() + + c.m.rc.OnAfterResponse(func(rc *resty.Client, r *resty.Response) error { + if clientID, ok := ClientIDFromContext(r.Request.Context()); !ok || clientID != c.clientID { + return nil + } + + return hook(rc, r) + }) +} + +func (c *Client) Close() { + c.attPool().Done() + + c.authLock.Lock() + defer c.authLock.Unlock() + + c.uid = "" + c.acc = "" + c.ref = "" + c.exp = time.Time{} + + c.hookLock.Lock() + defer c.hookLock.Unlock() + + c.authHandlers = nil + c.deauthHandlers = nil +} + +func (c *Client) withAuth(acc, ref string, exp time.Time) *Client { + c.acc = acc + c.ref = ref + c.exp = exp + + return c +} + +func (c *Client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) error { + if _, err := c.doRes(ctx, fn); err != nil { + return err + } + + return nil +} + +func (c *Client) doRes(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) { + c.hookLock.RLock() + defer c.hookLock.RUnlock() + + req, err := c.newReq(ctx) + if err != nil { + return nil, err + } + + // Perform the request. + res, err := fn(req) + + // If we receive no response, we can't do anything. + if res.RawResponse == nil { + return nil, fmt.Errorf("received no response from API: %w", err) + } + + // If we receive a 401, notify deauth handlers. + if res.StatusCode() == http.StatusUnauthorized { + c.deauthOnce.Do(func() { + for _, handler := range c.deauthHandlers { + handler() + } + }) + } + + return res, err +} + +func (c *Client) newReq(ctx context.Context) (*resty.Request, error) { + c.authLock.Lock() + defer c.authLock.Unlock() + + r := c.m.r(WithClient(ctx, c.clientID)) + + if c.uid != "" { + r.SetHeader("x-pm-uid", c.uid) + } + + if time.Now().After(c.exp) { + auth, err := c.m.authRefresh(ctx, c.uid, c.ref) + if err != nil { + return nil, err + } + + c.acc = auth.AccessToken + c.ref = auth.RefreshToken + c.exp = time.Now().Add(time.Duration(auth.ExpiresIn) * time.Second) + + if err := c.handleAuth(auth); err != nil { + return nil, err + } + } + + if c.acc != "" { + r.SetAuthToken(c.acc) + } + + return r, nil +} + +func (c *Client) handleAuth(auth Auth) error { + c.hookLock.RLock() + defer c.hookLock.RUnlock() + + for _, handler := range c.authHandlers { + handler(auth) + } + + return nil +} diff --git a/contact.go b/contact.go new file mode 100644 index 0000000..91be119 --- /dev/null +++ b/contact.go @@ -0,0 +1,141 @@ +package proton + +import ( + "context" + "strconv" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetContact(ctx context.Context, contactID string) (Contact, error) { + var res struct { + Contact Contact + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/contacts/v4/" + contactID) + }); err != nil { + return Contact{}, err + } + + return res.Contact, nil +} + +func (c *Client) CountContacts(ctx context.Context) (int, error) { + var res struct { + Total int + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/contacts/v4") + }); err != nil { + return 0, err + } + + return res.Total, nil +} + +func (c *Client) CountContactEmails(ctx context.Context, email string) (int, error) { + var res struct { + Total int + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).SetQueryParam("Email", email).Get("/contacts/v4/emails") + }); err != nil { + return 0, err + } + + return res.Total, nil +} + +func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) { + var res struct { + Contacts []Contact + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetQueryParams(map[string]string{ + "Page": strconv.Itoa(page), + "PageSize": strconv.Itoa(pageSize), + }).SetResult(&res).Get("/contacts/v4") + }); err != nil { + return nil, err + } + + return res.Contacts, nil +} + +func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) { + total, err := c.CountContacts(ctx) + if err != nil { + return nil, err + } + + return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]Contact, error) { + return c.GetContacts(ctx, page, pageSize) + }) +} + +func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) { + var res struct { + ContactEmails []ContactEmail + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetQueryParams(map[string]string{ + "Page": strconv.Itoa(page), + "PageSize": strconv.Itoa(pageSize), + "Email": email, + }).SetResult(&res).Get("/contacts/v4/emails") + }); err != nil { + return nil, err + } + + return res.ContactEmails, nil +} + +func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) { + total, err := c.CountContactEmails(ctx, email) + if err != nil { + return nil, err + } + + return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) { + return c.GetContactEmails(ctx, email, page, pageSize) + }) +} + +func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) { + var res struct { + Responses []CreateContactsRes + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).Post("/contacts/v4") + }); err != nil { + return nil, err + } + + return res.Responses, nil +} + +func (c *Client) UpdateContact(ctx context.Context, contactID string, req UpdateContactReq) (Contact, error) { + var res struct { + Contact Contact + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).Put("/contacts/v4/" + contactID) + }); err != nil { + return Contact{}, err + } + + return res.Contact, nil +} + +func (c *Client) DeleteContacts(ctx context.Context, req DeleteContactsReq) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).Put("/contacts/v4/delete") + }) +} diff --git a/contact_card.go b/contact_card.go new file mode 100644 index 0000000..bab35ba --- /dev/null +++ b/contact_card.go @@ -0,0 +1,374 @@ +package proton + +import ( + "bytes" + "errors" + "strings" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/xslices" + "github.com/emersion/go-vcard" +) + +const ( + FieldPMScheme = "X-PM-SCHEME" + FieldPMSign = "X-PM-SIGN" + FieldPMEncrypt = "X-PM-ENCRYPT" + FieldPMMIMEType = "X-PM-MIMETYPE" +) + +type Cards []*Card + +func (c *Cards) Merge(kr *crypto.KeyRing) (vcard.Card, error) { + merged := newVCard() + + for _, card := range *c { + dec, err := card.decode(kr) + if err != nil { + return nil, err + } + + for k, fields := range dec { + for _, f := range fields { + merged.Add(k, f) + } + } + } + + return merged, nil +} + +func (c *Cards) Get(cardType CardType) (*Card, bool) { + for _, card := range *c { + if card.Type == cardType { + return card, true + } + } + + return nil, false +} + +type Card struct { + Type CardType + Data string + Signature string +} + +type CardType int + +const ( + CardTypeClear CardType = iota + CardTypeEncrypted + CardTypeSigned +) + +func NewCard(kr *crypto.KeyRing, cardType CardType) (*Card, error) { + card := &Card{Type: cardType} + + if err := card.encode(kr, newVCard()); err != nil { + return nil, err + } + + return card, nil +} + +func newVCard() vcard.Card { + card := make(vcard.Card) + + card.AddValue(vcard.FieldVersion, "4.0") + + return card +} + +func (c Card) Get(kr *crypto.KeyRing, key string) ([]*vcard.Field, error) { + dec, err := c.decode(kr) + if err != nil { + return nil, err + } + + return dec[key], nil +} + +func (c *Card) Set(kr *crypto.KeyRing, key, value string) error { + dec, err := c.decode(kr) + if err != nil { + return err + } + + if field := dec.Get(key); field != nil { + field.Value = value + + return c.encode(kr, dec) + } + + dec.AddValue(key, value) + + return c.encode(kr, dec) +} + +func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error { + dec, err := c.decode(kr) + if err != nil { + return err + } + + c.Type = cardType + + return c.encode(kr, dec) +} + +// GetGroup returns a type to manipulate the group defined by the given key/value pair. +func (c Card) GetGroup(kr *crypto.KeyRing, groupKey, groupValue string) (CardGroup, error) { + group, err := c.getGroup(kr, groupKey, groupValue) + if err != nil { + return CardGroup{}, err + } + + return CardGroup{Card: c, kr: kr, group: group}, nil +} + +// DeleteGroup removes all values in the group defined by the given key/value pair. +func (c *Card) DeleteGroup(kr *crypto.KeyRing, groupKey, groupValue string) error { + group, err := c.getGroup(kr, groupKey, groupValue) + if err != nil { + return err + } + + return c.deleteGroup(kr, group) +} + +type CardGroup struct { + Card + + kr *crypto.KeyRing + group string +} + +// Get returns the values in the group with the given key. +func (g CardGroup) Get(key string) ([]string, error) { + dec, err := g.decode(g.kr) + if err != nil { + return nil, err + } + + var fields []*vcard.Field + + for _, field := range dec[key] { + if field.Group != g.group { + continue + } + + fields = append(fields, field) + } + + return xslices.Map(fields, func(field *vcard.Field) string { + return field.Value + }), nil +} + +// Set sets the value in the group. +func (g *CardGroup) Set(key, value string, params vcard.Params) error { + dec, err := g.decode(g.kr) + if err != nil { + return err + } + + for _, field := range dec[key] { + if field.Group != g.group { + continue + } + + field.Value = value + + return g.encode(g.kr, dec) + } + + dec.Add(key, &vcard.Field{ + Value: value, + Group: g.group, + Params: params, + }) + + return g.encode(g.kr, dec) +} + +// Add adds a value to the group. +func (g *CardGroup) Add(key, value string, params vcard.Params) error { + dec, err := g.decode(g.kr) + if err != nil { + return err + } + + dec.Add(key, &vcard.Field{ + Value: value, + Group: g.group, + Params: params, + }) + + return g.encode(g.kr, dec) +} + +// Remove removes the value in the group with the given key/value. +func (g *CardGroup) Remove(key, value string) error { + dec, err := g.decode(g.kr) + if err != nil { + return err + } + + fields, ok := dec[key] + if !ok { + return errors.New("no such key") + } + + var rest []*vcard.Field + + for _, field := range fields { + if field.Group != g.group { + rest = append(rest, field) + } else if field.Value != value { + rest = append(rest, field) + } + } + + if len(rest) > 0 { + dec[key] = rest + } else { + delete(dec, key) + } + + return g.encode(g.kr, dec) +} + +// RemoveAll removes all values in the group with the given key. +func (g *CardGroup) RemoveAll(key string) error { + dec, err := g.decode(g.kr) + if err != nil { + return err + } + + fields, ok := dec[key] + if !ok { + return errors.New("no such key") + } + + var rest []*vcard.Field + + for _, field := range fields { + if field.Group != g.group { + rest = append(rest, field) + } + } + + if len(rest) > 0 { + dec[key] = rest + } else { + delete(dec, key) + } + + return g.encode(g.kr, dec) +} + +func (c Card) getGroup(kr *crypto.KeyRing, groupKey, groupValue string) (string, error) { + fields, err := c.Get(kr, groupKey) + if err != nil { + return "", err + } + + for _, field := range fields { + if field.Value != groupValue { + continue + } + + return field.Group, nil + } + + return "", errors.New("no such field") +} + +func (c *Card) deleteGroup(kr *crypto.KeyRing, group string) error { + dec, err := c.decode(kr) + if err != nil { + return err + } + + for key, fields := range dec { + var rest []*vcard.Field + + for _, field := range fields { + if field.Group != group { + rest = append(rest, field) + } + } + + if len(rest) > 0 { + dec[key] = rest + } else { + delete(dec, key) + } + } + + return c.encode(kr, dec) +} + +func (c Card) decode(kr *crypto.KeyRing) (vcard.Card, error) { + if c.Type&CardTypeEncrypted != 0 { + enc, err := crypto.NewPGPMessageFromArmored(c.Data) + if err != nil { + return nil, err + } + + dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime()) + if err != nil { + return nil, err + } + + c.Data = dec.GetString() + } + + if c.Type&CardTypeSigned != 0 { + sig, err := crypto.NewPGPSignatureFromArmored(c.Signature) + if err != nil { + return nil, err + } + + if err := kr.VerifyDetached(crypto.NewPlainMessageFromString(c.Data), sig, crypto.GetUnixTime()); err != nil { + return nil, err + } + } + + return vcard.NewDecoder(strings.NewReader(c.Data)).Decode() +} + +func (c *Card) encode(kr *crypto.KeyRing, card vcard.Card) error { + buf := new(bytes.Buffer) + + if err := vcard.NewEncoder(buf).Encode(card); err != nil { + return err + } + + if c.Type&CardTypeSigned != 0 { + sig, err := kr.SignDetached(crypto.NewPlainMessageFromString(buf.String())) + if err != nil { + return err + } + + if c.Signature, err = sig.GetArmored(); err != nil { + return err + } + } + + if c.Type&CardTypeEncrypted != 0 { + enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(buf.String()), nil) + if err != nil { + return err + } + + if c.Data, err = enc.GetArmored(); err != nil { + return err + } + } else { + c.Data = buf.String() + } + + return nil +} diff --git a/contact_types.go b/contact_types.go new file mode 100644 index 0000000..24a0e7a --- /dev/null +++ b/contact_types.go @@ -0,0 +1,171 @@ +package proton + +import ( + "encoding/base64" + "strconv" + "strings" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/emersion/go-vcard" +) + +type RecipientType int + +const ( + RecipientTypeInternal RecipientType = iota + 1 + RecipientTypeExternal +) + +type ContactSettings struct { + MIMEType *rfc822.MIMEType + Scheme *EncryptionScheme + Sign *bool + Encrypt *bool + Keys []*crypto.Key +} + +type Contact struct { + ContactMetadata + ContactCards +} + +func (c *Contact) GetSettings(kr *crypto.KeyRing, email string) (ContactSettings, error) { + signedCard, ok := c.Cards.Get(CardTypeSigned) + if !ok { + return ContactSettings{}, nil + } + + group, err := signedCard.GetGroup(kr, vcard.FieldEmail, email) + if err != nil { + return ContactSettings{}, nil + } + + var settings ContactSettings + + scheme, err := group.Get(FieldPMScheme) + if err != nil { + return ContactSettings{}, err + } + + if len(scheme) > 0 { + switch scheme[0] { + case "pgp-inline": + settings.Scheme = newPtr(PGPInlineScheme) + + case "pgp-mime": + settings.Scheme = newPtr(PGPMIMEScheme) + } + } + + mimeType, err := group.Get(FieldPMMIMEType) + if err != nil { + return ContactSettings{}, err + } + + if len(mimeType) > 0 { + settings.MIMEType = newPtr(rfc822.MIMEType(mimeType[0])) + } + + sign, err := group.Get(FieldPMSign) + if err != nil { + return ContactSettings{}, err + } + + if len(sign) > 0 { + sign, err := strconv.ParseBool(sign[0]) + if err != nil { + return ContactSettings{}, err + } + + settings.Sign = newPtr(sign) + } + + encrypt, err := group.Get(FieldPMEncrypt) + if err != nil { + return ContactSettings{}, err + } + + if len(encrypt) > 0 { + encrypt, err := strconv.ParseBool(encrypt[0]) + if err != nil { + return ContactSettings{}, err + } + + settings.Encrypt = newPtr(encrypt) + } + + keys, err := group.Get(vcard.FieldKey) + if err != nil { + return ContactSettings{}, err + } + + if len(keys) > 0 { + for _, key := range keys { + dec, err := base64.StdEncoding.DecodeString(strings.SplitN(key, ",", 2)[1]) + if err != nil { + return ContactSettings{}, err + } + + pubKey, err := crypto.NewKey(dec) + if err != nil { + return ContactSettings{}, err + } + + settings.Keys = append(settings.Keys, pubKey) + } + } + + return settings, nil +} + +type ContactMetadata struct { + ID string + Name string + UID string + Size int64 + CreateTime int64 + ModifyTime int64 + ContactEmails []ContactEmail + LabelIDs []string +} + +type ContactCards struct { + Cards Cards +} + +type ContactEmail struct { + ID string + Name string + Email string + Type []string + ContactID string + LabelIDs []string +} + +type CreateContactsReq struct { + Contacts []ContactCards + Overwrite int + Labels int +} + +type CreateContactsRes struct { + Index int + + Response struct { + Error + Contact Contact + } +} + +type UpdateContactReq struct { + Cards Cards +} + +type DeleteContactsReq struct { + IDs []string +} + +func newPtr[T any](v T) *T { + return &v +} diff --git a/contexts.go b/contexts.go new file mode 100644 index 0000000..2657fc9 --- /dev/null +++ b/contexts.go @@ -0,0 +1,22 @@ +package proton + +import "context" + +type withClientKeyType struct{} + +var withClientKey withClientKeyType + +// WithClient marks this context as originating from the client with the given ID. +func WithClient(parent context.Context, clientID uint64) context.Context { + return context.WithValue(parent, withClientKey, clientID) +} + +// ClientIDFromContext returns true if this context was marked as originating from a client. +func ClientIDFromContext(ctx context.Context) (uint64, bool) { + clientID, ok := ctx.Value(withClientKey).(uint64) + if !ok { + return 0, false + } + + return clientID, true +} diff --git a/dialer.go b/dialer.go new file mode 100644 index 0000000..66c8eb1 --- /dev/null +++ b/dialer.go @@ -0,0 +1,311 @@ +package proton + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" +) + +func InsecureTransport() *http.Transport { + return &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } +} + +// NetCtl can be used to control whether a dialer can dial, and whether the resulting +// connection can read or write. +type NetCtl struct { + canDial atomicBool + dialLimit atomicUint64 + + canRead atomicBool + readLimit atomicUint64 + + canWrite atomicBool + writeLimit atomicUint64 + + onDial []func(net.Conn) + onRead []func([]byte) + onWrite []func([]byte) + + lock sync.Mutex +} + +// NewNetCtl returns a new NetCtl with all fields set to true. +func NewNetCtl() *NetCtl { + return &NetCtl{ + canDial: atomicBool{b32(true)}, + canRead: atomicBool{b32(true)}, + canWrite: atomicBool{b32(true)}, + } +} + +// SetCanDial sets whether the dialer can dial. +func (c *NetCtl) SetCanDial(canDial bool) { + c.canDial.Store(canDial) +} + +// SetDialLimit sets the maximum number of times dialers using this controller can dial. +func (c *NetCtl) SetDialLimit(limit uint64) { + c.dialLimit.Store(limit) +} + +// SetCanRead sets whether the connection can read. +func (c *NetCtl) SetCanRead(canRead bool) { + c.canRead.Store(canRead) +} + +// SetReadLimit sets the maximum number of bytes that can be read. +func (c *NetCtl) SetReadLimit(limit uint64) { + c.readLimit.Store(limit) +} + +// SetCanWrite sets whether the connection can write. +func (c *NetCtl) SetCanWrite(canWrite bool) { + c.canWrite.Store(canWrite) +} + +// SetWriteLimit sets the maximum number of bytes that can be written. +func (c *NetCtl) SetWriteLimit(limit uint64) { + c.writeLimit.Store(limit) +} + +// OnDial adds a callback that is called with the created connection when a dial is successful. +func (c *NetCtl) OnDial(f func(net.Conn)) { + c.lock.Lock() + defer c.lock.Unlock() + + c.onDial = append(c.onDial, f) +} + +// OnRead adds a callback that is called with the read bytes when a read is successful. +func (c *NetCtl) OnRead(f func([]byte)) { + c.lock.Lock() + defer c.lock.Unlock() + + c.onRead = append(c.onRead, f) +} + +// OnWrite adds a callback that is called with the written bytes when a write is successful. +func (c *NetCtl) OnWrite(f func([]byte)) { + c.lock.Lock() + defer c.lock.Unlock() + + c.onWrite = append(c.onWrite, f) +} + +// Disable is equivalent to disallowing dial, read and write. +func (c *NetCtl) Disable() { + c.SetCanDial(false) + c.SetCanRead(false) + c.SetCanWrite(false) +} + +// Enable is equivalent to allowing dial, read and write. +func (c *NetCtl) Enable() { + c.SetCanDial(true) + c.SetCanRead(true) + c.SetCanWrite(true) +} + +// Conn is a wrapper around net.Conn that can be used to control whether a connection can read or write. +type Conn struct { + net.Conn + + ctl *NetCtl + + readLimiter *readLimiter + writeLimiter *writeLimiter +} + +// Read reads from the wrapped connection, but only if the controller allows it. +func (c *Conn) Read(b []byte) (int, error) { + if !c.ctl.canRead.Load() { + return 0, errors.New("cannot read") + } + + n, err := c.readLimiter.read(c.Conn, b) + if err != nil { + return n, err + } + + for _, f := range c.ctl.onRead { + f(b[:n]) + } + + return n, err +} + +// Write writes to the wrapped connection, but only if the controller allows it. +func (c *Conn) Write(b []byte) (int, error) { + if !c.ctl.canWrite.Load() { + return 0, errors.New("cannot write") + } + + n, err := c.writeLimiter.write(c.Conn, b) + if err != nil { + return n, err + } + + for _, f := range c.ctl.onWrite { + f(b[:n]) + } + + return n, err +} + +// Dialer performs network dialing, but only if the controller allows it. +type Dialer struct { + ctl *NetCtl + + netDialer *net.Dialer + tlsDialer *tls.Dialer + tlsConfig *tls.Config + + readLimiter *readLimiter + writeLimiter *writeLimiter + + dialCount atomicUint64 +} + +// NewDialer returns a new dialer using the given net controller. +// It optionally uses a provided tls config. +func NewDialer(ctl *NetCtl, tlsConfig *tls.Config) *Dialer { + return &Dialer{ + ctl: ctl, + + netDialer: &net.Dialer{}, + tlsDialer: &tls.Dialer{Config: tlsConfig}, + tlsConfig: tlsConfig, + + readLimiter: newReadLimiter(ctl), + writeLimiter: newWriteLimiter(ctl), + + dialCount: atomicUint64{0}, + } +} + +// DialContext dials a network connection, but only if the controller allows it. +func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return d.dialWithDialer(ctx, network, addr, d.netDialer) +} + +// DialTLSContext dials a TLS network connection, but only if the controller allows it. +func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + return d.dialWithDialer(ctx, network, addr, d.tlsDialer) +} + +// dialWithDialer dials a network connection using the given dialer, but only if the controller allows it. +func (d *Dialer) dialWithDialer(ctx context.Context, network, addr string, dialer dialer) (net.Conn, error) { + if !d.ctl.canDial.Load() { + return nil, errors.New("cannot dial") + } + + if limit := d.ctl.dialLimit.Load(); limit > 0 && d.dialCount.Load() >= limit { + return nil, errors.New("dial limit reached") + } else { + d.dialCount.Add(1) + } + + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + d.ctl.lock.Lock() + defer d.ctl.lock.Unlock() + + for _, f := range d.ctl.onDial { + f(conn) + } + + return &Conn{ + Conn: conn, + ctl: d.ctl, + + readLimiter: d.readLimiter, + writeLimiter: d.writeLimiter, + }, nil +} + +// GetRoundTripper returns a new http.RoundTripper that uses the dialer. +func (d *Dialer) GetRoundTripper() http.RoundTripper { + return &http.Transport{ + DialContext: d.DialContext, + DialTLSContext: d.DialTLSContext, + TLSClientConfig: d.tlsConfig, + } +} + +type dialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + +type readLimiter struct { + ctl *NetCtl + + count atomicUint64 +} + +// newReadLimiter returns a new io.Reader that reads from r, but only up to limit bytes. +func newReadLimiter(ctl *NetCtl) *readLimiter { + return &readLimiter{ + ctl: ctl, + } +} + +func (limiter *readLimiter) read(r io.Reader, b []byte) (int, error) { + if limit := limiter.ctl.readLimit.Load(); limit > 0 && limiter.count.Load() >= limit { + return 0, fmt.Errorf("refusing to read: read limit reached") + } + + n, err := r.Read(b) + if err != nil { + return n, err + } + + if limit := limiter.ctl.readLimit.Load(); limit > 0 { + if new := limiter.count.Add(uint64(n)); new >= limit { + return 0, fmt.Errorf("read failed: read limit reached") + } + } + + return n, err +} + +type writeLimiter struct { + ctl *NetCtl + + count atomicUint64 +} + +// newWriteLimiter returns a new io.Writer that writes to w, but only up to limit bytes. +func newWriteLimiter(ctl *NetCtl) *writeLimiter { + return &writeLimiter{ + ctl: ctl, + } +} + +func (limiter *writeLimiter) write(w io.Writer, b []byte) (int, error) { + if limit := limiter.ctl.writeLimit.Load(); limit > 0 && limiter.count.Load() >= limit { + return 0, fmt.Errorf("refusing to write: write limit reached") + } + + n, err := w.Write(b) + if err != nil { + return n, err + } + + if limit := limiter.ctl.writeLimit.Load(); limit > 0 { + if new := limiter.count.Add(uint64(n)); new >= limit { + return 0, fmt.Errorf("write failed: write limit reached") + } + } + + return n, err +} diff --git a/dialer_test.go b/dialer_test.go new file mode 100644 index 0000000..6ef32c2 --- /dev/null +++ b/dialer_test.go @@ -0,0 +1,79 @@ +package proton_test + +import ( + "bytes" + "crypto/tls" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/ProtonMail/go-proton-api" +) + +func TestNetCtl_ReadLimit(t *testing.T) { + // Create a test http server that writes 100 bytes. + // Including the header, this is 217 bytes (100 bytes + 117 bytes). + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write(make([]byte, 100)); err != nil { + t.Fatal(err) + } + })) + defer ts.Close() + + // Create a new net controller. + netCtl := proton.NewNetCtl() + + // Create a new http client with the dialer. + client := &http.Client{ + Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + } + + // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. + netCtl.SetReadLimit(300) + + // This should succeed. + if resp, err := client.Get(ts.URL); err != nil { + t.Fatal(err) + } else { + resp.Body.Close() + } + + // This should fail. + if _, err := client.Get(ts.URL); err == nil { + t.Fatal("expected error") + } +} + +func TestNetCtl_WriteLimit(t *testing.T) { + // Create a test http server that reads the given body. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := io.ReadAll(r.Body); err != nil { + t.Fatal(err) + } + })) + defer ts.Close() + + // Create a new net controller. + netCtl := proton.NewNetCtl() + + // Create a new http client with the dialer. + client := &http.Client{ + Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + } + + // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. + netCtl.SetWriteLimit(300) + + // This should succeed. + if resp, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err != nil { + t.Fatal(err) + } else { + resp.Body.Close() + } + + // This should fail. + if _, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err == nil { + t.Fatal("expected error") + } +} diff --git a/drive_types.go b/drive_types.go new file mode 100644 index 0000000..d7a3ac1 --- /dev/null +++ b/drive_types.go @@ -0,0 +1,204 @@ +package proton + +import "github.com/ProtonMail/gopenpgp/v2/crypto" + +type Volume struct { + ID string // Encrypted volume ID + Name string // The volume name + OwnerUserID string // Encrypted owner user ID + UsedSpace int64 // Space used by files in the volume in bytes + MaxSpace int64 // Space limit for the volume in bytes + State VolumeState // TODO: What is this? +} + +type VolumeState int + +const ( +// TODO: VolumeState constants +) + +type Share struct { + ShareID string // Encrypted share ID + Type ShareType // Type of share + State ShareState // TODO: What is this? + PermissionsMask Permissions // Mask restricting member permissions on the share + LinkID string // Encrypted link ID to which the share points (root of share). + LinkType LinkType // TODO: What is this? + VolumeID string // Encrypted volume ID on which the share is mounted + Creator string // Creator address + AddressID string + Flags ShareFlags // The flag bitmap, with the following values + BlockSize int64 // TODO: What is this? + Locked bool // TODO: What is this? + Key string // The private key, encrypted with a passphrase + Passphrase string // The encrypted passphrase + PassphraseSignature string // The signature of the passphrase +} + +func (s Share) GetKeyRing(kr *crypto.KeyRing) (*crypto.KeyRing, error) { + encPass, err := crypto.NewPGPMessageFromArmored(s.Passphrase) + if err != nil { + return nil, err + } + + decPass, err := kr.Decrypt(encPass, nil, crypto.GetUnixTime()) + if err != nil { + return nil, err + } + + lockedKey, err := crypto.NewKeyFromArmored(s.Key) + if err != nil { + return nil, err + } + + unlockedKey, err := lockedKey.Unlock(decPass.GetBinary()) + if err != nil { + return nil, err + } + + return crypto.NewKeyRing(unlockedKey) +} + +type ShareType int + +const ( +// TODO: ShareType constants +) + +type ShareState int + +const ( +// TODO: ShareState constants +) + +type ShareFlags int + +const ( + NoFlags ShareFlags = iota + PrimaryShare +) + +type Link struct { + LinkID string // Encrypted file/folder ID + ParentLinkID string // Encrypted parent folder ID (LinkID) + Type LinkType + Name string // Encrypted file name + Hash string // HMAC of name encrypted with parent hash key + State LinkState // State of the link + ExpirationTime int64 + Size int64 + MIMEType string + Attributes Attributes + Permissions Permissions + + NodeKey string + NodePassphrase string + NodePassphraseSignature string + SignatureAddress string + + CreateTime int64 + ModifyTime int64 + + FileProperties FileProperties + FolderProperties FolderProperties +} + +func (l Link) GetKeyRing(kr *crypto.KeyRing) (*crypto.KeyRing, error) { + encPass, err := crypto.NewPGPMessageFromArmored(l.NodePassphrase) + if err != nil { + return nil, err + } + + decPass, err := kr.Decrypt(encPass, nil, crypto.GetUnixTime()) + if err != nil { + return nil, err + } + + lockedKey, err := crypto.NewKeyFromArmored(l.NodeKey) + if err != nil { + return nil, err + } + + unlockedKey, err := lockedKey.Unlock(decPass.GetBinary()) + if err != nil { + return nil, err + } + + return crypto.NewKeyRing(unlockedKey) +} + +type FileProperties struct { + ContentKeyPacket string + ActiveRevision Revision +} + +type FolderProperties struct{} + +type LinkType int + +const ( + FolderLinkType LinkType = iota + 1 + FileLinkType +) + +type LinkState int + +const ( + DraftLinkState LinkState = iota + ActiveLinkState + TrashedLinkState + DeletedLinkState +) + +type Revision struct { + ID string // Encrypted Revision ID + CreateTime int64 // Unix timestamp of the revision creation time + Size int64 // Size of the file in bytes + ManifestSignature string // The signature of the root hash + SignatureAddress string // The address used to sign the root hash + State FileRevisionState // State of revision + Blocks []Block +} + +type FileRevisionState int + +const ( + DraftRevisionState FileRevisionState = iota + ActiveRevisionState + ObsoleteRevisionState +) + +type Block struct { + Index int + URL string + EncSignature string + SignatureEmail string +} + +type LinkEvent struct { + EventID string // Encrypted ID of the Event + CreateTime int64 // Time stamp of the creation time of the Event + EventType LinkEventType // Type of event +} + +type LinkEventType int + +const ( + DeleteLinkEvent LinkEventType = iota + CreateLinkEvent + UpdateContentsLinkEvent + UpdateMetadataLinkEvent +) + +type Permissions int + +const ( + NoPermissions Permissions = 1 << iota + ReadPermission + WritePermission + AdministerMembersPermission + AdminPermission + SuperAdminPermission +) + +type Attributes uint32 diff --git a/event.go b/event.go new file mode 100644 index 0000000..8cfe6db --- /dev/null +++ b/event.go @@ -0,0 +1,102 @@ +package proton + +import ( + "context" + "time" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetLatestEventID(ctx context.Context) (string, error) { + var res struct { + Event + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/core/v4/events/latest") + }); err != nil { + return "", err + } + + return res.EventID, nil +} + +func (c *Client) GetEvent(ctx context.Context, eventID string) (Event, error) { + event, more, err := c.getEvent(ctx, eventID) + if err != nil { + return Event{}, err + } + + for more { + var next Event + + next, more, err = c.getEvent(ctx, event.EventID) + if err != nil { + return Event{}, err + } + + if err := event.merge(next); err != nil { + return Event{}, err + } + } + + return event, nil +} + +// NewEventStreamer returns a new event stream. +// It polls the API for new events at random intervals between `period` and `period+jitter`. +func (c *Client) NewEventStream(ctx context.Context, period, jitter time.Duration, lastEventID string) <-chan Event { + eventCh := make(chan Event) + + go func() { + defer close(eventCh) + + ticker := NewTicker(period, jitter) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + + case <-ticker.C: + // ... + } + + event, err := c.GetEvent(ctx, lastEventID) + if err != nil { + continue + } + + if event.EventID == lastEventID { + continue + } + + select { + case <-ctx.Done(): + return + + case eventCh <- event: + lastEventID = event.EventID + } + } + }() + + return eventCh +} + +func (c *Client) getEvent(ctx context.Context, eventID string) (Event, bool, error) { + var res struct { + Event + + More Bool + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/core/v4/events/" + eventID) + }); err != nil { + return Event{}, false, err + } + + return res.Event, bool(res.More), nil +} diff --git a/event_test.go b/event_test.go new file mode 100644 index 0000000..c5782ed --- /dev/null +++ b/event_test.go @@ -0,0 +1,70 @@ +package proton_test + +import ( + "context" + "testing" + "time" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server" + "github.com/stretchr/testify/require" +) + +func TestEventStreamer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := server.New() + defer s.Close() + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.InsecureTransport()), + ) + + _, _, err := s.CreateUser("username", "email@pm.me", []byte("password")) + require.NoError(t, err) + + c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password")) + require.NoError(t, err) + + createTestMessages(t, c, "password", 10) + + latestEventID, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + + eventCh := make(chan proton.Event) + + go func() { + for event := range c.NewEventStream(ctx, time.Second, 0, latestEventID) { + eventCh <- event + } + }() + + // Perform some action to generate an event. + metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) + require.NoError(t, err) + require.NoError(t, c.LabelMessages(ctx, []string{metadata[0].ID}, proton.TrashLabel)) + + // Wait for the first event. + <-eventCh + + // Close the client; this should stop the client's event streamer. + c.Close() + + // Create a new client and perform some actions with it to generate more events. + cc, _, err := m.NewClientWithLogin(ctx, "username", []byte("password")) + require.NoError(t, err) + defer cc.Close() + + require.NoError(t, cc.LabelMessages(ctx, []string{metadata[1].ID}, proton.TrashLabel)) + + // We should not receive any more events from the original client. + select { + case <-eventCh: + require.Fail(t, "received unexpected event") + + default: + // ... + } +} diff --git a/event_types.go b/event_types.go new file mode 100644 index 0000000..a8ae826 --- /dev/null +++ b/event_types.go @@ -0,0 +1,140 @@ +package proton + +import ( + "fmt" + "strings" + + "github.com/bradenaw/juniper/xslices" +) + +type Event struct { + EventID string + + Refresh RefreshFlag + + User *User + + MailSettings *MailSettings + + Messages []MessageEvent + + Labels []LabelEvent + + Addresses []AddressEvent +} + +func (event Event) String() string { + var parts []string + + if event.Refresh != 0 { + parts = append(parts, fmt.Sprintf("refresh: %v", event.Refresh)) + } + + if event.User != nil { + parts = append(parts, "user: [modified]") + } + + if event.MailSettings != nil { + parts = append(parts, "mail-settings: [modified]") + } + + if len(event.Messages) > 0 { + parts = append(parts, fmt.Sprintf( + "messages: created=%d, updated=%d, deleted=%d", + xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventCreate }), + xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }), + xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventDelete }), + )) + } + + if len(event.Labels) > 0 { + parts = append(parts, fmt.Sprintf( + "labels: created=%d, updated=%d, deleted=%d", + xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventCreate }), + xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }), + xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventDelete }), + )) + } + + if len(event.Addresses) > 0 { + parts = append(parts, fmt.Sprintf( + "addresses: created=%d, updated=%d, deleted=%d", + xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventCreate }), + xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }), + xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventDelete }), + )) + } + + return fmt.Sprintf("Event %s: %s", event.EventID, strings.Join(parts, ", ")) +} + +// merge combines this event with the other event (assumed to be newer!). +// TODO: Intelligent merging: if there are multiple EventUpdate(Flags) events, can we just take the latest one? +func (event *Event) merge(other Event) error { + event.EventID = other.EventID + + if other.User != nil { + event.User = other.User + } + + if other.MailSettings != nil { + event.MailSettings = other.MailSettings + } + + // For now, label events are simply appended. + event.Labels = append(event.Labels, other.Labels...) + + // For now, message events are simply appended. + event.Messages = append(event.Messages, other.Messages...) + + // For now, address events are simply appended. + event.Addresses = append(event.Addresses, other.Addresses...) + + return nil +} + +type RefreshFlag uint8 + +const ( + RefreshMail RefreshFlag = 1 << iota // 1<<0 = 1 + _ // 1<<1 = 2 + _ // 1<<2 = 4 + _ // 1<<3 = 8 + _ // 1<<4 = 16 + _ // 1<<5 = 32 + _ // 1<<6 = 64 + _ // 1<<7 = 128 + RefreshAll RefreshFlag = 1<5s, and we only allow 1s in the context. + // Thus, it will fail. + c := m.NewClient("", "", "", time.Now().Add(time.Hour)) + defer c.Close() + + if _, err := c.GetAddresses(ctx); err == nil { + t.Fatal("expected error, instead got", err) + } +} + +func TestReturnErrNoConnection(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + // We will fail more times than we retry, so requests should fail with ErrNoConnection. + m := proton.New( + proton.WithHostURL(ts.URL), + proton.WithRetryCount(5), + proton.WithTransport(newFailingRoundTripper(10)), + ) + + // The call should fail because every dial will fail and we'll run out of retries. + c := m.NewClient("", "", "", time.Now().Add(time.Hour)) + defer c.Close() + + if _, err := c.GetAddresses(context.Background()); err == nil { + t.Fatal("expected error, instead got", err) + } +} + +func TestStatusCallbacks(t *testing.T) { + s := server.New() + defer s.Close() + + ctl := proton.NewNetCtl() + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + ) + + statusCh := make(chan proton.Status, 1) + + m.AddStatusObserver(func(status proton.Status) { + statusCh <- status + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctl.Disable() + + require.Error(t, m.Ping(ctx)) + require.Equal(t, proton.StatusDown, <-statusCh) + + ctl.Enable() + + require.NoError(t, m.Ping(ctx)) + require.Equal(t, proton.StatusUp, <-statusCh) + + ctl.SetReadLimit(1) + + require.Error(t, m.Ping(ctx)) + require.Equal(t, proton.StatusDown, <-statusCh) + + ctl.SetReadLimit(0) + + require.NoError(t, m.Ping(ctx)) + require.Equal(t, proton.StatusUp, <-statusCh) +} + +type failingRoundTripper struct { + http.RoundTripper + + fails, calls int +} + +func newFailingRoundTripper(fails int) http.RoundTripper { + return &failingRoundTripper{ + RoundTripper: http.DefaultTransport, + fails: fails, + } +} + +func (rt *failingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.calls++ + + if rt.calls < rt.fails { + return nil, errors.New("simulating network error") + } + + return rt.RoundTripper.RoundTrip(req) +} diff --git a/message.go b/message.go new file mode 100644 index 0000000..5f9823d --- /dev/null +++ b/message.go @@ -0,0 +1,274 @@ +package proton + +import ( + "context" + "fmt" + "runtime" + "strconv" + + "github.com/bradenaw/juniper/iterator" + "github.com/bradenaw/juniper/parallel" + "github.com/bradenaw/juniper/stream" + "github.com/bradenaw/juniper/xslices" + "github.com/go-resty/resty/v2" +) + +const maxMessageIDs = 1000 + +func (c *Client) GetFullMessage(ctx context.Context, messageID string) (FullMessage, error) { + message, err := c.GetMessage(ctx, messageID) + if err != nil { + return FullMessage{}, err + } + + attData, err := c.attPool().ProcessAll(ctx, xslices.Map(message.Attachments, func(att Attachment) string { + return att.ID + })) + if err != nil { + return FullMessage{}, err + } + + return FullMessage{ + Message: message, + AttData: attData, + }, nil +} + +func (c *Client) GetFullMessages(ctx context.Context, workers, buffer int, messageIDs ...string) stream.Stream[FullMessage] { + return parallel.MapStream( + ctx, + stream.FromIterator(iterator.Slice(messageIDs)), + workers, + buffer, + c.GetFullMessage, + ) +} + +func (c *Client) GetMessage(ctx context.Context, messageID string) (Message, error) { + var res struct { + Message Message + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/mail/v4/messages/" + messageID) + }); err != nil { + return Message{}, err + } + + return res.Message, nil +} + +func (c *Client) CountMessages(ctx context.Context) (int, error) { + var res struct { + Total int + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetQueryParams(map[string]string{ + "Limit": strconv.Itoa(0), + }).SetResult(&res).Get("/mail/v4/messages") + }); err != nil { + return 0, err + } + + return res.Total, nil +} + +func (c *Client) GetMessageMetadata(ctx context.Context, filter MessageFilter) ([]MessageMetadata, error) { + total, err := c.CountMessages(ctx) + if err != nil { + return nil, err + } + + return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]MessageMetadata, error) { + return c.getMessageMetadata(ctx, page, pageSize, filter) + }) +} + +func (c *Client) GetMessageIDs(ctx context.Context, afterID string) ([]string, error) { + var messageIDs []string + + for ; ; afterID = messageIDs[len(messageIDs)-1] { + page, err := c.getMessageIDs(ctx, afterID) + if err != nil { + return nil, err + } + + if len(page) == 0 { + return messageIDs, nil + } + + messageIDs = append(messageIDs, page...) + } +} + +func (c *Client) DeleteMessage(ctx context.Context, messageIDs ...string) error { + pages := xslices.Chunk(messageIDs, maxPageSize) + + return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/delete") + }) + }) +} + +func (c *Client) MarkMessagesRead(ctx context.Context, messageIDs ...string) error { + pages := xslices.Chunk(messageIDs, maxPageSize) + + return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/read") + }) + }) +} + +func (c *Client) MarkMessagesUnread(ctx context.Context, messageIDs ...string) error { + pages := xslices.Chunk(messageIDs, maxPageSize) + + return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error { + req := MessageActionReq{IDs: pages[idx]} + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).Put("/mail/v4/messages/unread") + }); err != nil { + return err + } + + return nil + }) +} + +func (c *Client) LabelMessages(ctx context.Context, messageIDs []string, labelID string) error { + res, err := parallel.MapContext( + ctx, + runtime.NumCPU(), + xslices.Chunk(messageIDs, maxPageSize), + func(ctx context.Context, messageIDs []string) (LabelMessagesRes, error) { + var res LabelMessagesRes + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(LabelMessagesReq{ + LabelID: labelID, + IDs: messageIDs, + }).SetResult(&res).Put("/mail/v4/messages/label") + }); err != nil { + return LabelMessagesRes{}, err + } + + return res, nil + }, + ) + if err != nil { + return err + } + + if idx := xslices.IndexFunc(res, func(res LabelMessagesRes) bool { return !res.ok() }); idx >= 0 { + tokens := xslices.Map(res, func(res LabelMessagesRes) UndoToken { + return res.UndoToken + }) + + if _, undoErr := c.UndoActions(ctx, tokens...); undoErr != nil { + return fmt.Errorf("failed to undo actions: %w", undoErr) + } + + return fmt.Errorf("failed to label messages") + } + + return nil +} + +func (c *Client) UnlabelMessages(ctx context.Context, messageIDs []string, labelID string) error { + res, err := parallel.MapContext( + ctx, + runtime.NumCPU(), + xslices.Chunk(messageIDs, maxPageSize), + func(ctx context.Context, messageIDs []string) (LabelMessagesRes, error) { + var res LabelMessagesRes + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(LabelMessagesReq{ + LabelID: labelID, + IDs: messageIDs, + }).SetResult(&res).Put("/mail/v4/messages/unlabel") + }); err != nil { + return LabelMessagesRes{}, err + } + + return res, nil + }, + ) + if err != nil { + return err + } + + if idx := xslices.IndexFunc(res, func(res LabelMessagesRes) bool { return !res.ok() }); idx >= 0 { + tokens := xslices.Map(res, func(res LabelMessagesRes) UndoToken { + return res.UndoToken + }) + + if _, undoErr := c.UndoActions(ctx, tokens...); undoErr != nil { + return fmt.Errorf("failed to undo actions: %w", undoErr) + } + + return fmt.Errorf("failed to unlabel messages") + } + + return nil +} + +func (c *Client) getMessageIDs(ctx context.Context, afterID string) ([]string, error) { + var res struct { + IDs []string + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + if afterID != "" { + r = r.SetQueryParam("AfterID", afterID) + } + + return r.SetQueryParam("Limit", strconv.Itoa(maxMessageIDs)).SetResult(&res).Get("/mail/v4/messages/ids") + }); err != nil { + return nil, err + } + + return res.IDs, nil +} + +func (c *Client) getMessageMetadata(ctx context.Context, page, pageSize int, filter MessageFilter) ([]MessageMetadata, error) { + var res struct { + Messages []MessageMetadata + Stale Bool + } + + req := struct { + MessageFilter + + Page int + PageSize int + + Sort string + Desc Bool + }{ + MessageFilter: filter, + + Page: page, + PageSize: pageSize, + + Sort: "ID", + Desc: false, + } + + for { + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).SetHeader("X-HTTP-Method-Override", "GET").Post("/mail/v4/messages") + }); err != nil { + return nil, err + } + + if !res.Stale { + break + } + } + + return res.Messages, nil +} diff --git a/message_build.go b/message_build.go new file mode 100644 index 0000000..1a2e96d --- /dev/null +++ b/message_build.go @@ -0,0 +1,318 @@ +package proton + +import ( + "bufio" + "bytes" + "encoding/base64" + "io" + "mime" + "net/mail" + "strings" + "time" + "unicode/utf8" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/emersion/go-message" + "github.com/emersion/go-message/textproto" + "github.com/google/uuid" +) + +func BuildRFC822(kr *crypto.KeyRing, msg Message, attData map[string][]byte) ([]byte, error) { + if msg.MIMEType == rfc822.MultipartMixed { + return buildPGPRFC822(kr, msg) + } + + header, err := getMixedMessageHeader(msg) + if err != nil { + return nil, err + } + + buf := new(bytes.Buffer) + + w, err := message.CreateWriter(buf, header) + if err != nil { + return nil, err + } + + var ( + inlineAtts []Attachment + inlineData [][]byte + attachAtts []Attachment + attachData [][]byte + ) + + for _, att := range msg.Attachments { + if att.Disposition == InlineDisposition { + inlineAtts = append(inlineAtts, att) + inlineData = append(inlineData, attData[att.ID]) + } else { + attachAtts = append(attachAtts, att) + attachData = append(attachData, attData[att.ID]) + } + } + + if len(inlineAtts) > 0 { + if err := writeRelatedParts(w, kr, msg, inlineAtts, inlineData); err != nil { + return nil, err + } + } else if err := writeTextPart(w, kr, msg); err != nil { + return nil, err + } + + for i, att := range attachAtts { + if err := writeAttachmentPart(w, kr, att, attachData[i]); err != nil { + return nil, err + } + } + + if err := w.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func writeTextPart(w *message.Writer, kr *crypto.KeyRing, msg Message) error { + dec, err := msg.Decrypt(kr) + if err != nil { + return err + } + + part, err := w.CreatePart(getTextPartHeader(dec, msg.MIMEType)) + if err != nil { + return err + } + + if _, err := part.Write(dec); err != nil { + return err + } + + return part.Close() +} + +func writeAttachmentPart(w *message.Writer, kr *crypto.KeyRing, att Attachment, attData []byte) error { + kps, err := base64.StdEncoding.DecodeString(att.KeyPackets) + if err != nil { + return err + } + + msg := crypto.NewPGPSplitMessage(kps, attData).GetPGPMessage() + + dec, err := kr.Decrypt(msg, nil, crypto.GetUnixTime()) + if err != nil { + return err + } + + part, err := w.CreatePart(getAttachmentPartHeader(att)) + if err != nil { + return err + } + + if _, err := part.Write(dec.GetBinary()); err != nil { + return err + } + + return part.Close() +} + +func writeRelatedParts(w *message.Writer, kr *crypto.KeyRing, msg Message, atts []Attachment, attData [][]byte) error { + var header message.Header + + header.SetContentType(string(rfc822.MultipartRelated), nil) + + rel, err := w.CreatePart(header) + if err != nil { + return err + } + + if err := writeTextPart(rel, kr, msg); err != nil { + return err + } + + for i, att := range atts { + if err := writeAttachmentPart(rel, kr, att, attData[i]); err != nil { + return err + } + } + + return rel.Close() +} + +func buildPGPRFC822(kr *crypto.KeyRing, msg Message) ([]byte, error) { + raw, err := textproto.ReadHeader(bufio.NewReader(strings.NewReader(msg.Header))) + if err != nil { + return nil, err + } + + dec, err := msg.Decrypt(kr) + if err != nil { + return nil, err + } + + sigs, err := ExtractSignatures(kr, msg.Body) + if err != nil { + return nil, err + } + + if len(sigs) > 0 { + return buildMultipartSignedRFC822(message.Header{Header: raw}, dec, sigs[0]) + } + + return buildMultipartEncryptedRFC822(message.Header{Header: raw}, dec) +} + +func buildMultipartSignedRFC822(header message.Header, body []byte, sig Signature) ([]byte, error) { + buf := new(bytes.Buffer) + + boundary := uuid.New().String() + + header.SetContentType("multipart/signed", map[string]string{ + "micalg": sig.Hash, + "protocol": "application/pgp-signature", + "boundary": boundary, + }) + + if err := textproto.WriteHeader(buf, header.Header); err != nil { + return nil, err + } + + w := rfc822.NewMultipartWriter(buf, boundary) + + bodyHeader, bodyData := rfc822.Split(body) + + if err := w.AddPart(func(w io.Writer) error { + if _, err := w.Write(bodyHeader); err != nil { + return err + } + + if _, err := w.Write(bodyData); err != nil { + return err + } + + return nil + }); err != nil { + return nil, err + } + + var sigHeader message.Header + + sigHeader.SetContentType("application/pgp-signature", map[string]string{"name": "OpenPGP_signature.asc"}) + sigHeader.SetContentDisposition("attachment", map[string]string{"filename": "OpenPGP_signature"}) + sigHeader.Set("Content-Description", "OpenPGP digital signature") + + sigData, err := sig.Data.GetArmored() + if err != nil { + return nil, err + } + + if err := w.AddPart(func(w io.Writer) error { + if err := textproto.WriteHeader(w, sigHeader.Header); err != nil { + return err + } + + if _, err := w.Write([]byte(sigData)); err != nil { + return err + } + + return nil + }); err != nil { + return nil, err + } + + if err := w.Done(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func buildMultipartEncryptedRFC822(header message.Header, body []byte) ([]byte, error) { + buf := new(bytes.Buffer) + + bodyHeader, bodyData := rfc822.Split(body) + + parsedHeader, err := rfc822.NewHeader(bodyHeader) + if err != nil { + return nil, err + } + + parsedHeader.Entries(func(key, val string) { + header.Set(key, val) + }) + + if err := textproto.WriteHeader(buf, header.Header); err != nil { + return nil, err + } + + if _, err := buf.Write(bodyData); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func getMixedMessageHeader(msg Message) (message.Header, error) { + raw, err := textproto.ReadHeader(bufio.NewReader(strings.NewReader(msg.Header))) + if err != nil { + return message.Header{}, err + } + + header := message.Header{Header: raw} + + header.SetContentType(string(rfc822.MultipartMixed), nil) + + if date, err := mail.ParseDate(header.Get("Date")); err != nil || date.Before(time.Unix(0, 0)) { + if msgTime := time.Unix(msg.Time, 0); msgTime.After(time.Unix(0, 0)) { + header.Set("Date", msgTime.In(time.UTC).Format(time.RFC1123Z)) + } else { + header.Del("Date") + } + + header.Set("X-Original-Date", date.In(time.UTC).Format(time.RFC1123Z)) + } + + return header, nil +} + +func getTextPartHeader(body []byte, mimeType rfc822.MIMEType) message.Header { + var header message.Header + + params := make(map[string]string) + + if utf8.Valid(body) { + params["charset"] = "utf-8" + } + + header.SetContentType(string(mimeType), params) + + // Use quoted-printable for all text/... parts + header.Set("Content-Transfer-Encoding", "quoted-printable") + + return header +} + +func getAttachmentPartHeader(att Attachment) message.Header { + var header message.Header + + for key, val := range att.Headers { + for _, val := range val { + header.Add(key, val) + } + } + + // All attachments have a content type. + header.SetContentType(string(att.MIMEType), map[string]string{"name": mime.QEncoding.Encode("utf-8", att.Name)}) + + // All attachments have a content disposition. + header.SetContentDisposition(string(att.Disposition), map[string]string{"filename": mime.QEncoding.Encode("utf-8", att.Name)}) + + // Use base64 for all attachments except embedded RFC822 messages. + if att.MIMEType != rfc822.MessageRFC822 { + header.Set("Content-Transfer-Encoding", "base64") + } else { + header.Del("Content-Transfer-Encoding") + } + + return header +} diff --git a/message_draft_types.go b/message_draft_types.go new file mode 100644 index 0000000..9a52527 --- /dev/null +++ b/message_draft_types.go @@ -0,0 +1,37 @@ +package proton + +import ( + "net/mail" + + "github.com/ProtonMail/gluon/rfc822" +) + +type DraftTemplate struct { + Subject string + Sender *mail.Address + ToList []*mail.Address + CCList []*mail.Address + BCCList []*mail.Address + Body string + MIMEType rfc822.MIMEType + + ExternalID string `json:",omitempty"` +} + +type CreateDraftAction int + +const ( + ReplyAction CreateDraftAction = iota + ReplyAllAction + ForwardAction + AutoResponseAction + ReadReceiptAction +) + +type CreateDraftReq struct { + Message DraftTemplate + AttachmentKeyPackets []string + + ParentID string `json:",omitempty"` + Action CreateDraftAction +} diff --git a/message_encrypt.go b/message_encrypt.go new file mode 100644 index 0000000..729d07c --- /dev/null +++ b/message_encrypt.go @@ -0,0 +1,123 @@ +package proton + +import ( + "bytes" + "io" + "mime" + "strings" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/google/uuid" +) + +// CharsetReader returns a charset decoder for the given charset. +// If set, it will be used to decode non-utf8 encoded messages. +var CharsetReader func(charset string, input io.Reader) (io.Reader, error) + +// EncryptRFC822 encrypts the given message literal as a PGP attachment. +func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) { + msg, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr) + if err != nil { + return nil, err + } + armored, err := msg.GetArmored() + if err != nil { + return nil, err + } + + header, _ := rfc822.Split(literal) + + headerParsed, err := rfc822.NewHeader(header) + if err != nil { + return nil, err + } + + buf := new(bytes.Buffer) + boundary := strings.ReplaceAll(uuid.NewString(), "-", "") + multipartWriter := rfc822.NewMultipartWriter(buf, boundary) + + { + newHeader := rfc822.NewEmptyHeader() + + if value, ok := headerParsed.GetChecked("Message-Id"); ok { + newHeader.Set("Message-Id", value) + } + + contentType := mime.FormatMediaType("multipart/encrypted", map[string]string{ + "boundary": boundary, + "protocol": "application/pgp-encrypted", + }) + newHeader.Set("Mime-version", "1.0") + newHeader.Set("Content-Type", contentType) + + if value, ok := headerParsed.GetChecked("From"); ok { + newHeader.Set("From", value) + } + + if value, ok := headerParsed.GetChecked("To"); ok { + newHeader.Set("To", value) + } + + if value, ok := headerParsed.GetChecked("Subject"); ok { + newHeader.Set("Subject", value) + } + + if value, ok := headerParsed.GetChecked("Date"); ok { + newHeader.Set("Date", value) + } + + if value, ok := headerParsed.GetChecked("Received"); ok { + newHeader.Set("Received", value) + } + + buf.Write(newHeader.Raw()) + } + + // Write PGP control data + { + pgpControlHeader := rfc822.NewEmptyHeader() + pgpControlHeader.Set("Content-Description", "PGP/MIME version identification") + pgpControlHeader.Set("Content-Type", "application/pgp-encrypted") + if err := multipartWriter.AddPart(func(writer io.Writer) error { + if _, err := writer.Write(pgpControlHeader.Raw()); err != nil { + return err + } + + _, err := writer.Write([]byte("Version: 1")) + + return err + }); err != nil { + return nil, err + } + } + + // write PGP attachment + { + pgpAttachmentHeader := rfc822.NewEmptyHeader() + contentType := mime.FormatMediaType("application/octet-stream", map[string]string{ + "name": "encrypted.asc", + }) + pgpAttachmentHeader.Set("Content-Description", "OpenPGP encrypted message") + pgpAttachmentHeader.Set("Content-Disposition", "inline; filename=encrypted.asc") + pgpAttachmentHeader.Set("Content-Type", contentType) + + if err := multipartWriter.AddPart(func(writer io.Writer) error { + if _, err := writer.Write(pgpAttachmentHeader.Raw()); err != nil { + return err + } + + _, err := writer.Write([]byte(armored)) + return err + }); err != nil { + return nil, err + } + } + + // finish messsage + if err := multipartWriter.Done(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/message_encrypt_test.go b/message_encrypt_test.go new file mode 100644 index 0000000..d9b37f6 --- /dev/null +++ b/message_encrypt_test.go @@ -0,0 +1,96 @@ +package proton + +import ( + "bytes" + "testing" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncryptMessage(t *testing.T) { + const message = `From: Nathaniel Borenstein +To: Ned Freed +Subject: Sample message (import 2) +MIME-Version: 1.0 +Content-type: multipart/mixed; boundary="simple boundary" +Received: from mail.protonmail.ch by mail.protonmail.ch; Tue, 25 Nov 2016 + +This is the preamble. It is to be ignored, though it +is a handy place for mail composers to include an +explanatory note to non-MIME compliant readers. +--simple boundary + +This is implicitly typed plain ASCII text. +It does NOT end with a linebreak. +--simple boundary +Content-type: text/plain; charset=us-ascii + +This is explicitly typed plain ASCII text. +It DOES end with a linebreak. + +--simple boundary-- +This is the epilogue. It is also to be ignored. +` + key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0) + require.NoError(t, err) + + kr, err := crypto.NewKeyRing(key) + require.NoError(t, err) + + encryptedMessage, err := EncryptRFC822(kr, []byte(message)) + require.NoError(t, err) + + section := rfc822.Parse(encryptedMessage) + + { + // Check root header: + header, err := section.ParseHeader() + require.NoError(t, err) + + assert.Equal(t, header.Get("From"), "Nathaniel Borenstein ") + assert.Equal(t, header.Get("To"), "Ned Freed ") + assert.Equal(t, header.Get("Subject"), "Sample message (import 2)") + assert.Equal(t, header.Get("MIME-Version"), "1.0") + assert.Equal(t, header.Get("Received"), "from mail.protonmail.ch by mail.protonmail.ch; Tue, 25 Nov 2016") + + mediaType, params, err := rfc822.ParseMediaType(header.Get("Content-Type")) + require.NoError(t, err) + assert.Equal(t, "multipart/encrypted", mediaType) + assert.Equal(t, "application/pgp-encrypted", params["protocol"]) + assert.NotEmpty(t, params["boundary"]) + } + + children, err := section.Children() + require.NoError(t, err) + require.Equal(t, 2, len(children)) + + { + // check first child. + child := children[0] + header, err := child.ParseHeader() + require.NoError(t, err) + + assert.Equal(t, header.Get("Content-Description"), "PGP/MIME version identification") + assert.Equal(t, header.Get("Content-Type"), "application/pgp-encrypted") + + assert.Equal(t, []byte("Version: 1"), child.Body()) + } + + { + // check second child. + child := children[1] + header, err := child.ParseHeader() + require.NoError(t, err) + + assert.Equal(t, header.Get("Content-Description"), "OpenPGP encrypted message") + assert.Equal(t, header.Get("Content-Disposition"), "inline; filename=encrypted.asc") + assert.Equal(t, header.Get("Content-type"), "application/octet-stream; name=encrypted.asc") + + body := child.Body() + assert.True(t, bytes.HasPrefix(body, []byte("-----BEGIN PGP MESSAGE-----"))) + assert.True(t, bytes.HasSuffix(body, []byte("-----END PGP MESSAGE-----"))) + } +} diff --git a/message_import.go b/message_import.go new file mode 100644 index 0000000..dec827c --- /dev/null +++ b/message_import.go @@ -0,0 +1,84 @@ +package proton + +import ( + "context" + "fmt" + "strconv" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/iterator" + "github.com/bradenaw/juniper/parallel" + "github.com/bradenaw/juniper/stream" + "github.com/bradenaw/juniper/xslices" + "github.com/go-resty/resty/v2" +) + +const maxImportSize = 10 + +func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, workers, buffer int, req ...ImportReq) stream.Stream[ImportRes] { + return stream.Flatten(parallel.MapStream( + ctx, + stream.FromIterator(iterator.Chunk(iterator.Slice(req), maxImportSize)), + workers, + buffer, + func(ctx context.Context, req []ImportReq) (stream.Stream[ImportRes], error) { + res, err := c.importMessages(ctx, addrKR, req) + if err != nil { + return nil, fmt.Errorf("failed to import messages: %w", err) + } + + for _, res := range res { + if res.Code != SuccessCode { + return nil, fmt.Errorf("failed to import message: %w", res.Error) + } + } + + return stream.FromIterator(iterator.Slice(res)), nil + }, + )) +} + +func (c *Client) importMessages(ctx context.Context, addrKR *crypto.KeyRing, req []ImportReq) ([]ImportRes, error) { + names := iterator.Collect(iterator.Map(iterator.Counter(len(req)), func(i int) string { + return strconv.Itoa(i) + })) + + var named []namedImportReq + + for idx, name := range names { + named = append(named, namedImportReq{ + ImportReq: req[idx], + Name: name, + }) + } + + fields, err := buildImportReqFields(addrKR, named) + if err != nil { + return nil, err + } + + type namedImportRes struct { + Name string + Response ImportRes + } + + var res struct { + Responses []namedImportRes + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetMultipartFields(fields...).SetResult(&res).Post("/mail/v4/messages/import") + }); err != nil { + return nil, err + } + + namedRes := make(map[string]ImportRes, len(res.Responses)) + + for _, res := range res.Responses { + namedRes[res.Name] = res.Response + } + + return xslices.Map(names, func(name string) ImportRes { + return namedRes[name] + }), nil +} diff --git a/message_import_types.go b/message_import_types.go new file mode 100644 index 0000000..f60fdbb --- /dev/null +++ b/message_import_types.go @@ -0,0 +1,68 @@ +package proton + +import ( + "bytes" + "encoding/json" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/go-resty/resty/v2" +) + +type ImportReq struct { + Metadata ImportMetadata + Message []byte +} + +type namedImportReq struct { + ImportReq + + Name string +} + +type ImportMetadata struct { + AddressID string + LabelIDs []string + Unread Bool + Flags MessageFlag +} + +type ImportRes struct { + Error + MessageID string +} + +func buildImportReqFields(addrKR *crypto.KeyRing, req []namedImportReq) ([]*resty.MultipartField, error) { + var fields []*resty.MultipartField + + metadata := make(map[string]ImportMetadata) + + for _, req := range req { + metadata[req.Name] = req.Metadata + + enc, err := EncryptRFC822(addrKR, req.Message) + if err != nil { + return nil, err + } + + fields = append(fields, &resty.MultipartField{ + Param: req.Name, + FileName: req.Name + ".eml", + ContentType: string(rfc822.MessageRFC822), + Reader: bytes.NewReader(append(enc, "\r\n"...)), + }) + } + + b, err := json.Marshal(metadata) + if err != nil { + return nil, err + } + + fields = append(fields, &resty.MultipartField{ + Param: "Metadata", + ContentType: "application/json", + Reader: bytes.NewReader(b), + }) + + return fields, nil +} diff --git a/message_send.go b/message_send.go new file mode 100644 index 0000000..9bdc9ff --- /dev/null +++ b/message_send.go @@ -0,0 +1,35 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) CreateDraft(ctx context.Context, req CreateDraftReq) (Message, error) { + var res struct { + Message Message + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages") + }); err != nil { + return Message{}, err + } + + return res.Message, nil +} + +func (c *Client) SendDraft(ctx context.Context, draftID string, req SendDraftReq) (Message, error) { + var res struct { + Sent Message + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages/" + draftID) + }); err != nil { + return Message{}, err + } + + return res.Sent, nil +} diff --git a/message_send_types.go b/message_send_types.go new file mode 100644 index 0000000..69ed1f4 --- /dev/null +++ b/message_send_types.go @@ -0,0 +1,308 @@ +package proton + +import ( + "encoding/base64" + "fmt" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" +) + +type EncryptionScheme int + +const ( + InternalScheme EncryptionScheme = 1 << iota + EncryptedOutsideScheme + ClearScheme + PGPInlineScheme + PGPMIMEScheme + ClearMIMEScheme +) + +type SignatureType int + +const ( + NoSignature SignatureType = iota + DetachedSignature + AttachedSignature +) + +type MessageRecipient struct { + Type EncryptionScheme + Signature SignatureType + + BodyKeyPacket string `json:",omitempty"` + AttachmentKeyPackets map[string]string `json:",omitempty"` +} + +type MessagePackage struct { + Addresses map[string]*MessageRecipient + MIMEType rfc822.MIMEType + Type EncryptionScheme + Body string + + BodyKey *SessionKey `json:",omitempty"` + AttachmentKeys map[string]*SessionKey `json:",omitempty"` +} + +func newMessagePackage(mimeType rfc822.MIMEType, encBodyData []byte) *MessagePackage { + return &MessagePackage{ + Addresses: make(map[string]*MessageRecipient), + MIMEType: mimeType, + Body: base64.StdEncoding.EncodeToString(encBodyData), + + AttachmentKeys: make(map[string]*SessionKey), + } +} + +type SessionKey struct { + Key string + Algorithm string +} + +func newSessionKey(key *crypto.SessionKey) *SessionKey { + return &SessionKey{ + Key: key.GetBase64Key(), + Algorithm: key.Algo, + } +} + +type SendPreferences struct { + // Encrypt indicates whether the email should be encrypted or not. + // If it's encrypted, we need to know which public key to use. + Encrypt bool + + // PubKey contains an OpenPGP key that can be used for encryption. + PubKey *crypto.KeyRing + + // SignatureType indicates how the email should be signed. + SignatureType SignatureType + + // EncryptionScheme indicates if we should encrypt body and attachments separately and + // what MIME format to give the final encrypted email. The two standard PGP + // schemes are PGP/MIME and PGP/Inline. However we use a custom scheme for + // internal emails (including the so-called encrypted-to-outside emails, + // which even though meant for external users, they don't really get out of + // our platform). If the email is sent unencrypted, no PGP scheme is needed. + EncryptionScheme EncryptionScheme + + // MIMEType is the MIME type to use for formatting the body of the email + // (before encryption/after decryption). The standard possibilities are the + // enriched HTML format, text/html, and plain text, text/plain. But it's + // also possible to have a multipart/mixed format, which is typically used + // for PGP/MIME encrypted emails, where attachments go into the body too. + // Because of this, this option is sometimes called MIME format. + MIMEType rfc822.MIMEType +} + +type SendDraftReq struct { + Packages []*MessagePackage +} + +func (req *SendDraftReq) AddMIMEPackage( + kr *crypto.KeyRing, + mimeBody string, + prefs map[string]SendPreferences, +) error { + for _, prefs := range prefs { + if prefs.MIMEType != rfc822.MultipartMixed { + return fmt.Errorf("invalid MIME type for MIME package: %s", prefs.MIMEType) + } + } + + pkg, err := newMIMEPackage(kr, mimeBody, prefs) + if err != nil { + return err + } + + req.Packages = append(req.Packages, pkg) + + return nil +} + +func (req *SendDraftReq) AddTextPackage( + kr *crypto.KeyRing, + body string, + mimeType rfc822.MIMEType, + prefs map[string]SendPreferences, + attKeys map[string]*crypto.SessionKey, +) error { + pkg, err := newTextPackage(kr, body, mimeType, prefs, attKeys) + if err != nil { + return err + } + + req.Packages = append(req.Packages, pkg) + + return nil +} + +func newMIMEPackage( + kr *crypto.KeyRing, + mimeBody string, + prefs map[string]SendPreferences, +) (*MessagePackage, error) { + decBodyKey, encBodyData, err := encSplit(kr, mimeBody) + if err != nil { + return nil, fmt.Errorf("failed to encrypt MIME body: %w", err) + } + + pkg := newMessagePackage(rfc822.MultipartMixed, encBodyData) + + for addr, prefs := range prefs { + if prefs.MIMEType != rfc822.MultipartMixed { + return nil, fmt.Errorf("invalid MIME type for MIME package: %s", prefs.MIMEType) + } + + if prefs.SignatureType != DetachedSignature { + return nil, fmt.Errorf("invalid signature type for MIME package: %d", prefs.SignatureType) + } + + recipient := &MessageRecipient{ + Type: prefs.EncryptionScheme, + Signature: prefs.SignatureType, + } + + switch prefs.EncryptionScheme { + case PGPMIMEScheme: + if prefs.PubKey == nil { + return nil, fmt.Errorf("missing public key for %s", addr) + } + + encBodyKey, err := prefs.PubKey.EncryptSessionKey(decBodyKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt session key: %w", err) + } + + recipient.BodyKeyPacket = base64.StdEncoding.EncodeToString(encBodyKey) + + case ClearMIMEScheme: + pkg.BodyKey = &SessionKey{ + Key: decBodyKey.GetBase64Key(), + Algorithm: decBodyKey.Algo, + } + + default: + return nil, fmt.Errorf("invalid encryption scheme for MIME package: %d", prefs.EncryptionScheme) + } + + pkg.Addresses[addr] = recipient + pkg.Type |= prefs.EncryptionScheme + } + + return pkg, nil +} + +func newTextPackage( + kr *crypto.KeyRing, + body string, + mimeType rfc822.MIMEType, + prefs map[string]SendPreferences, + attKeys map[string]*crypto.SessionKey, +) (*MessagePackage, error) { + if mimeType != rfc822.TextPlain && mimeType != rfc822.TextHTML { + return nil, fmt.Errorf("invalid MIME type for package: %s", mimeType) + } + + decBodyKey, encBodyData, err := encSplit(kr, body) + if err != nil { + return nil, fmt.Errorf("failed to encrypt message body: %w", err) + } + + pkg := newMessagePackage(mimeType, encBodyData) + + for addr, prefs := range prefs { + if prefs.MIMEType != mimeType { + return nil, fmt.Errorf("invalid MIME type for package: %s", prefs.MIMEType) + } + + if prefs.SignatureType == DetachedSignature && !prefs.Encrypt { + if prefs.EncryptionScheme == PGPInlineScheme { + return nil, fmt.Errorf("invalid encryption scheme for %s: %d", addr, prefs.EncryptionScheme) + } + + if prefs.EncryptionScheme == ClearScheme && mimeType != rfc822.TextPlain { + return nil, fmt.Errorf("invalid MIME type for clear package: %s", mimeType) + } + } + + if prefs.EncryptionScheme == InternalScheme && !prefs.Encrypt { + return nil, fmt.Errorf("internal packages must be encrypted") + } + + if prefs.EncryptionScheme == PGPInlineScheme && mimeType != rfc822.TextPlain { + return nil, fmt.Errorf("invalid MIME type for PGP inline package: %s", mimeType) + } + + switch prefs.EncryptionScheme { + case ClearScheme: + pkg.BodyKey = newSessionKey(decBodyKey) + + for attID, attKey := range attKeys { + pkg.AttachmentKeys[attID] = newSessionKey(attKey) + } + + case InternalScheme, PGPInlineScheme: + // ... + + default: + return nil, fmt.Errorf("invalid encryption scheme for package: %d", prefs.EncryptionScheme) + } + + recipient := &MessageRecipient{ + Type: prefs.EncryptionScheme, + Signature: prefs.SignatureType, + AttachmentKeyPackets: make(map[string]string), + } + + if prefs.Encrypt { + if prefs.PubKey == nil { + return nil, fmt.Errorf("missing public key for %s", addr) + } + + if prefs.SignatureType != DetachedSignature { + return nil, fmt.Errorf("invalid signature type for package: %d", prefs.SignatureType) + } + + encBodyKey, err := prefs.PubKey.EncryptSessionKey(decBodyKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt session key: %w", err) + } + + recipient.BodyKeyPacket = base64.StdEncoding.EncodeToString(encBodyKey) + + for attID, attKey := range attKeys { + encAttKey, err := prefs.PubKey.EncryptSessionKey(attKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt attachment key: %w", err) + } + + recipient.AttachmentKeyPackets[attID] = base64.StdEncoding.EncodeToString(encAttKey) + } + } + + pkg.Addresses[addr] = recipient + pkg.Type |= prefs.EncryptionScheme + } + + return pkg, nil +} + +func encSplit(kr *crypto.KeyRing, body string) (*crypto.SessionKey, []byte, error) { + encBody, err := kr.Encrypt(crypto.NewPlainMessageFromString(body), kr) + if err != nil { + return nil, nil, fmt.Errorf("failed to encrypt MIME body: %w", err) + } + + splitEncBody, err := encBody.SplitMessage() + if err != nil { + return nil, nil, fmt.Errorf("failed to split message: %w", err) + } + + decBodyKey, err := kr.DecryptSessionKey(splitEncBody.GetBinaryKeyPacket()) + if err != nil { + return nil, nil, fmt.Errorf("failed to decrypt session key: %w", err) + } + + return decBodyKey, splitEncBody.GetBinaryDataPacket(), nil +} diff --git a/message_send_types_test.go b/message_send_types_test.go new file mode 100644 index 0000000..8a294c0 --- /dev/null +++ b/message_send_types_test.go @@ -0,0 +1,381 @@ +package proton + +import ( + "testing" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/stretchr/testify/require" +) + +func TestSendDraftReq_AddMIMEPackage(t *testing.T) { + key, err := crypto.GenerateKey("name", "email", "rsa", 2048) + require.NoError(t, err) + + kr, err := crypto.NewKeyRing(key) + require.NoError(t, err) + + tests := []struct { + name string + mimeBody string + prefs map[string]SendPreferences + wantErr bool + }{ + { + name: "Clear MIME with detached signature", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-sign@email.com": { + Encrypt: false, + SignatureType: DetachedSignature, + EncryptionScheme: ClearMIMEScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: false, + }, + { + name: "Clear MIME with no signature (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-no-sign@email.com": { + Encrypt: false, + SignatureType: NoSignature, + EncryptionScheme: ClearMIMEScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: true, + }, + { + name: "Clear MIME with plain text (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-plain@email.com": { + Encrypt: false, + SignatureType: DetachedSignature, + EncryptionScheme: ClearMIMEScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: true, + }, + { + name: "Clear MIME with rich text (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-html@email.com": { + Encrypt: false, + SignatureType: DetachedSignature, + EncryptionScheme: ClearMIMEScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "PGP MIME with detached signature", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-encrypted@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: PGPMIMEScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: false, + }, + { + name: "PGP MIME with plain text (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-encrypted-plain@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: PGPMIMEScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: true, + }, + { + name: "PGP MIME with rich text (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-encrypted-plain@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: PGPMIMEScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "PGP MIME with missing public key (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-encrypted-no-pubkey@email.com": { + Encrypt: true, + SignatureType: DetachedSignature, + EncryptionScheme: PGPMIMEScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: true, + }, + { + name: "PGP MIME with no signature (error)", + mimeBody: "this is a mime body", + prefs: map[string]SendPreferences{"mime-encrypted-no-signature@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: NoSignature, + EncryptionScheme: PGPMIMEScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req SendDraftReq + + if err := req.AddMIMEPackage(kr, tt.mimeBody, tt.prefs); (err != nil) != tt.wantErr { + t.Errorf("SendDraftReq.AddMIMEPackage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSendDraftReq_AddPackage(t *testing.T) { + key, err := crypto.GenerateKey("name", "email", "rsa", 2048) + require.NoError(t, err) + + kr, err := crypto.NewKeyRing(key) + require.NoError(t, err) + + tests := []struct { + name string + body string + mimeType rfc822.MIMEType + prefs map[string]SendPreferences + attKeys map[string]*crypto.SessionKey + wantErr bool + }{ + { + name: "internal plain text with detached signature", + body: "this is a text/plain body", + mimeType: rfc822.TextPlain, + prefs: map[string]SendPreferences{"internal-plain@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: false, + }, + { + name: "internal rich text with detached signature", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"internal-html@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: false, + }, + { + name: "internal rich text with bad package content type (error)", + body: "this is a text/html body", + mimeType: "bad content type", + prefs: map[string]SendPreferences{"internal-bad-package-content-type@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "internal rich text with bad recipient content type (error)", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"internal-bad-recipient-content-type@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: "bad content type", + }}, + wantErr: true, + }, + { + name: "internal with multipart (error)", + body: "this is a text/html body", + mimeType: rfc822.MultipartMixed, + prefs: map[string]SendPreferences{"internal-multipart-mixed@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: true, + }, + { + name: "internal without encryption (error)", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"internal-no-encrypt@email.com": { + Encrypt: false, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "internal without pubkey (error)", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"internal-no-pubkey@email.com": { + Encrypt: true, + SignatureType: DetachedSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "internal without signature (error)", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"internal-no-sig@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: NoSignature, + EncryptionScheme: InternalScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "clear rich text without signature", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"clear-rich@email.com": { + Encrypt: false, + SignatureType: NoSignature, + EncryptionScheme: ClearScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: false, + }, + { + name: "clear plain text without signature", + body: "this is a text/plain body", + mimeType: rfc822.TextPlain, + prefs: map[string]SendPreferences{"clear-plain@email.com": { + Encrypt: false, + SignatureType: NoSignature, + EncryptionScheme: ClearScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: false, + }, + { + name: "clear plain text with signature", + body: "this is a text/plain body", + mimeType: rfc822.TextPlain, + prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": { + Encrypt: false, + SignatureType: DetachedSignature, + EncryptionScheme: ClearScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: false, + }, + { + name: "clear plain text with bad scheme (error)", + body: "this is a text/plain body", + mimeType: rfc822.TextPlain, + prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": { + Encrypt: false, + SignatureType: DetachedSignature, + EncryptionScheme: PGPInlineScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: true, + }, + { + name: "clear rich text with signature (error)", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": { + Encrypt: false, + SignatureType: DetachedSignature, + EncryptionScheme: ClearScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "encrypted plain text with signature", + body: "this is a text/plain body", + mimeType: rfc822.TextPlain, + prefs: map[string]SendPreferences{"pgp-inline-with-sig@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: PGPInlineScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: false, + }, + { + name: "encrypted html text with signature (error)", + body: "this is a text/html body", + mimeType: rfc822.TextHTML, + prefs: map[string]SendPreferences{"pgp-inline-rich-with-sig@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: PGPInlineScheme, + MIMEType: rfc822.TextHTML, + }}, + wantErr: true, + }, + { + name: "encrypted mixed text with signature (error)", + body: "this is a multipart/mixed body", + mimeType: rfc822.MultipartMixed, + prefs: map[string]SendPreferences{"pgp-inline-mixed-with-sig@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: PGPInlineScheme, + MIMEType: rfc822.MultipartMixed, + }}, + wantErr: true, + }, + { + name: "encrypted for outside (error)", + body: "this is a text/plain body", + mimeType: rfc822.TextPlain, + prefs: map[string]SendPreferences{"enc-for-outside@email.com": { + Encrypt: true, + PubKey: kr, + SignatureType: DetachedSignature, + EncryptionScheme: EncryptedOutsideScheme, + MIMEType: rfc822.TextPlain, + }}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var req SendDraftReq + + if err := req.AddTextPackage(kr, tt.body, tt.mimeType, tt.prefs, tt.attKeys); (err != nil) != tt.wantErr { + t.Errorf("SendDraftReq.AddPackage() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/message_types.go b/message_types.go new file mode 100644 index 0000000..217c2af --- /dev/null +++ b/message_types.go @@ -0,0 +1,192 @@ +package proton + +import ( + "net/mail" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "golang.org/x/exp/slices" +) + +type MessageMetadata struct { + ID string + AddressID string + LabelIDs []string + ExternalID string + + Subject string + Sender *mail.Address + ToList []*mail.Address + CCList []*mail.Address + BCCList []*mail.Address + ReplyTos []*mail.Address + + Flags MessageFlag + Time int64 + Size int + Unread Bool + IsReplied Bool + IsRepliedAll Bool + IsForwarded Bool +} + +func (meta MessageMetadata) Seen() bool { + return !bool(meta.Unread) +} + +func (meta MessageMetadata) Starred() bool { + return slices.Contains(meta.LabelIDs, StarredLabel) +} + +func (meta MessageMetadata) IsDraft() bool { + return meta.Flags&(MessageFlagReceived|MessageFlagSent) == 0 +} + +type MessageFilter struct { + ID []string `json:",omitempty"` + + AddressID string `json:",omitempty"` + ExternalID string `json:",omitempty"` + LabelID string `json:",omitempty"` +} + +type Message struct { + MessageMetadata + + Header string + ParsedHeaders Headers + Body string + MIMEType rfc822.MIMEType + Attachments []Attachment +} + +type MessageFlag int64 + +const ( + MessageFlagReceived MessageFlag = 1 << iota + MessageFlagSent + MessageFlagInternal + MessageFlagE2E + MessageFlagAuto + MessageFlagReplied + MessageFlagRepliedAll + MessageFlagForwarded + MessageFlagAutoReplied + MessageFlagImported + MessageFlagOpened + MessageFlagReceiptSent + MessageFlagNotified + MessageFlagTouched + MessageFlagReceipt + MessageFlagProton + MessageFlagReceiptRequest + MessageFlagPublicKey + MessageFlagSign + MessageFlagUnsubscribed + MessageFlagSPFFail + MessageFlagDKIMFail + MessageFlagDMARCFail + MessageFlagHamManual + MessageFlagSpamAuto + MessageFlagSpamManual + MessageFlagPhishingAuto + MessageFlagPhishingManual +) + +func (f MessageFlag) Has(flag MessageFlag) bool { + return f&flag != 0 +} + +func (f MessageFlag) Matches(flag MessageFlag) bool { + return f&flag == flag +} + +func (f MessageFlag) HasAny(flags ...MessageFlag) bool { + for _, flag := range flags { + if f.Has(flag) { + return true + } + } + + return false +} + +func (f MessageFlag) HasAll(flags ...MessageFlag) bool { + for _, flag := range flags { + if !f.Has(flag) { + return false + } + } + + return true +} + +func (f MessageFlag) Add(flag MessageFlag) MessageFlag { + return f | flag +} + +func (f MessageFlag) Remove(flag MessageFlag) MessageFlag { + return f &^ flag +} + +func (f MessageFlag) Toggle(flag MessageFlag) MessageFlag { + if f.Has(flag) { + return f.Remove(flag) + } + + return f.Add(flag) +} + +func (m Message) Decrypt(kr *crypto.KeyRing) ([]byte, error) { + enc, err := crypto.NewPGPMessageFromArmored(m.Body) + if err != nil { + return nil, err + } + + dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime()) + if err != nil { + return nil, err + } + + return dec.GetBinary(), nil +} + +type FullMessage struct { + Message + + AttData [][]byte +} + +type Signature struct { + Hash string + Data *crypto.PGPSignature +} + +type MessageActionReq struct { + IDs []string +} + +type LabelMessagesReq struct { + LabelID string + IDs []string +} + +type LabelMessagesRes struct { + Responses []LabelMessageRes + UndoToken UndoToken +} + +func (res LabelMessagesRes) ok() bool { + for _, resp := range res.Responses { + if resp.Response.Code != SuccessCode { + return false + } + } + + return true +} + +type LabelMessageRes struct { + ID string + Response Error +} diff --git a/message_types_test.go b/message_types_test.go new file mode 100644 index 0000000..6d23479 --- /dev/null +++ b/message_types_test.go @@ -0,0 +1,49 @@ +package proton + +import ( + "os" + "testing" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/stretchr/testify/require" +) + +func TestDecrypt(t *testing.T) { + body, err := os.ReadFile("testdata/body.pgp") + require.NoError(t, err) + + pubKR := loadKeyRing(t, "testdata/pub.asc", nil) + prvKR := loadKeyRing(t, "testdata/prv.asc", []byte("password")) + + msg := Message{Body: string(body)} + + sigs, err := ExtractSignatures(prvKR, msg.Body) + require.NoError(t, err) + + enc, err := crypto.NewPGPMessageFromArmored(msg.Body) + require.NoError(t, err) + + dec, err := prvKR.Decrypt(enc, nil, crypto.GetUnixTime()) + require.NoError(t, err) + require.NoError(t, pubKR.VerifyDetached(dec, sigs[0].Data, crypto.GetUnixTime())) +} + +func loadKeyRing(t *testing.T, file string, pass []byte) *crypto.KeyRing { + f, err := os.Open(file) + require.NoError(t, err) + + defer f.Close() + + key, err := crypto.NewKeyFromArmoredReader(f) + require.NoError(t, err) + + if pass != nil { + key, err = key.Unlock(pass) + require.NoError(t, err) + } + + kr, err := crypto.NewKeyRing(key) + require.NoError(t, err) + + return kr +} diff --git a/option.go b/option.go new file mode 100644 index 0000000..31902ac --- /dev/null +++ b/option.go @@ -0,0 +1,138 @@ +package proton + +import ( + "net/http" + + "github.com/go-resty/resty/v2" +) + +// Option represents a type that can be used to configure the manager. +type Option interface { + config(*managerBuilder) +} + +func WithHostURL(hostURL string) Option { + return &withHostURL{ + hostURL: hostURL, + } +} + +type withHostURL struct { + hostURL string +} + +func (opt withHostURL) config(builder *managerBuilder) { + builder.hostURL = opt.hostURL +} + +func WithAppVersion(appVersion string) Option { + return &withAppVersion{ + appVersion: appVersion, + } +} + +type withAppVersion struct { + appVersion string +} + +func (opt withAppVersion) config(builder *managerBuilder) { + builder.appVersion = opt.appVersion +} + +func WithTransport(transport http.RoundTripper) Option { + return &withTransport{ + transport: transport, + } +} + +type withTransport struct { + transport http.RoundTripper +} + +func (opt withTransport) config(builder *managerBuilder) { + builder.transport = opt.transport +} + +type withAttPoolSize struct { + attPoolSize int +} + +func (opt withAttPoolSize) config(builder *managerBuilder) { + builder.attPoolSize = opt.attPoolSize +} + +func WithAttPoolSize(attPoolSize int) Option { + return &withAttPoolSize{ + attPoolSize: attPoolSize, + } +} + +type withSkipVerifyProofs struct { + skipVerifyProofs bool +} + +func (opt withSkipVerifyProofs) config(builder *managerBuilder) { + builder.verifyProofs = !opt.skipVerifyProofs +} + +func WithSkipVerifyProofs() Option { + return &withSkipVerifyProofs{ + skipVerifyProofs: true, + } +} + +func WithRetryCount(retryCount int) Option { + return &withRetryCount{ + retryCount: retryCount, + } +} + +type withRetryCount struct { + retryCount int +} + +func (opt withRetryCount) config(builder *managerBuilder) { + builder.retryCount = opt.retryCount +} + +func WithCookieJar(jar http.CookieJar) Option { + return &withCookieJar{ + jar: jar, + } +} + +type withCookieJar struct { + jar http.CookieJar +} + +func (opt withCookieJar) config(builder *managerBuilder) { + builder.cookieJar = opt.jar +} + +func WithLogger(logger resty.Logger) Option { + return &withLogger{ + logger: logger, + } +} + +type withLogger struct { + logger resty.Logger +} + +func (opt withLogger) config(builder *managerBuilder) { + builder.logger = opt.logger +} + +func WithDebug(debug bool) Option { + return &withDebug{ + debug: debug, + } +} + +type withDebug struct { + debug bool +} + +func (opt withDebug) config(builder *managerBuilder) { + builder.debug = opt.debug +} diff --git a/package.go b/package.go new file mode 100644 index 0000000..12f2e7f --- /dev/null +++ b/package.go @@ -0,0 +1,2 @@ +// Package proton implements types for accessing the Proton API. +package proton diff --git a/paging.go b/paging.go new file mode 100644 index 0000000..003ca79 --- /dev/null +++ b/paging.go @@ -0,0 +1,33 @@ +package proton + +import ( + "context" + "runtime" + + "github.com/bradenaw/juniper/iterator" + "github.com/bradenaw/juniper/parallel" + "github.com/bradenaw/juniper/stream" +) + +const maxPageSize = 150 + +func fetchPaged[T any]( + ctx context.Context, + total, pageSize int, + fn func(ctx context.Context, page, pageSize int) ([]T, error), +) ([]T, error) { + return stream.Collect(ctx, stream.Flatten(parallel.MapStream( + ctx, + stream.FromIterator(iterator.Counter(total/pageSize+1)), + runtime.NumCPU(), + runtime.NumCPU(), + func(ctx context.Context, page int) (stream.Stream[T], error) { + values, err := fn(ctx, page, pageSize) + if err != nil { + return nil, err + } + + return stream.FromIterator(iterator.Slice(values)), nil + }, + ))) +} diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..eaa257f --- /dev/null +++ b/pool.go @@ -0,0 +1,166 @@ +package proton + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/ProtonMail/gluon/queue" +) + +// ErrJobCancelled indicates the job was cancelled. +var ErrJobCancelled = errors.New("job cancelled by surrounding context") + +// Pool is a worker pool that handles input of type In and returns results of type Out. +type Pool[In comparable, Out any] struct { + queue *queue.QueuedChannel[*job[In, Out]] + wg sync.WaitGroup +} + +// doneFunc must be called to free up pool resources. +type doneFunc func() + +// New returns a new pool. +func NewPool[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] { + pool := &Pool[In, Out]{ + queue: queue.NewQueuedChannel[*job[In, Out]](0, 0), + } + + for i := 0; i < size; i++ { + pool.wg.Add(1) + + go func() { + defer pool.wg.Done() + + for job := range pool.queue.GetChannel() { + select { + case <-job.ctx.Done(): + job.postFailure(ErrJobCancelled) + + default: + res, err := work(job.ctx, job.req) + if err != nil { + job.postFailure(err) + } else { + job.postSuccess(res) + } + + job.waitDone() + } + } + }() + } + + return pool +} + +// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred. +func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(int, In, Out, error) error) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var ( + wg sync.WaitGroup + errList []error + lock sync.Mutex + ) + + for i, req := range reqs { + req := req + + wg.Add(1) + + go func(index int) { + defer wg.Done() + + job, done, err := pool.newJob(ctx, req) + if err != nil { + lock.Lock() + defer lock.Unlock() + + // Cancel ongoing jobs. + cancel() + + // Collect the error. + errList = append(errList, err) + + return + } + + defer done() + + res, err := job.result() + + if err := fn(index, req, res, err); err != nil { + lock.Lock() + defer lock.Unlock() + + // Cancel ongoing jobs. + cancel() + + // Collect the error. + errList = append(errList, err) + } + }(i) + } + + wg.Wait() + + // TODO: Join the errors somehow? + if len(errList) > 0 { + return errList[0] + } + + return nil +} + +// ProcessAll submits jobs to the pool. All results are returned once available. +func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) ([]Out, error) { + data := make([]Out, len(reqs)) + + if err := pool.Process(ctx, reqs, func(index int, req In, res Out, err error) error { + if err != nil { + return err + } + + data[index] = res + + return nil + }); err != nil { + return nil, err + } + + return data, nil +} + +// ProcessOne submits one job to the pool and returns the result. +func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) { + job, done, err := pool.newJob(ctx, req) + if err != nil { + var o Out + return o, err + } + + defer done() + + return job.result() +} + +func (pool *Pool[In, Out]) Done() { + pool.queue.Close() + pool.wg.Wait() +} + +// newJob submits a job to the pool. It returns a job handle and a DoneFunc. +// The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done, +// which frees up the worker in the pool for reuse. +func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc, error) { + job := newJob[In, Out](ctx, req) + + if !pool.queue.Enqueue(job) { + return nil, nil, fmt.Errorf("pool closed") + } + + return job, func() { close(job.done) }, nil +} diff --git a/pool_test.go b/pool_test.go new file mode 100644 index 0000000..bbc95d5 --- /dev/null +++ b/pool_test.go @@ -0,0 +1,173 @@ +package proton + +import ( + "context" + "errors" + "runtime" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPool_NewJob(t *testing.T) { + doubler := newDoubler(runtime.NumCPU()) + defer doubler.Done() + + job1, done1, err := doubler.newJob(context.Background(), 1) + require.NoError(t, err) + defer done1() + + job2, done2, err := doubler.newJob(context.Background(), 2) + require.NoError(t, err) + defer done2() + + res2, err := job2.result() + require.NoError(t, err) + + res1, err := job1.result() + require.NoError(t, err) + + assert.Equal(t, 2, res1) + assert.Equal(t, 4, res2) +} + +func TestPool_NewJob_Done(t *testing.T) { + // Create a doubler pool with 2 workers. + doubler := newDoubler(2) + defer doubler.Done() + + // Start two jobs. Don't mark the jobs as done yet. + job1, done1, err := doubler.newJob(context.Background(), 1) + require.NoError(t, err) + job2, done2, err := doubler.newJob(context.Background(), 2) + require.NoError(t, err) + + // Get the first result. + res1, _ := job1.result() + assert.Equal(t, 2, res1) + + // Get the first result. + res2, _ := job2.result() + assert.Equal(t, 4, res2) + + // Additional jobs will wait. + job3, done3, err := doubler.newJob(context.Background(), 3) + require.NoError(t, err) + job4, done4, err := doubler.newJob(context.Background(), 4) + require.NoError(t, err) + + // Channel to collect results from jobs 3 and 4. + resCh := make(chan int, 2) + + go func() { + res, _ := job3.result() + resCh <- res + }() + + go func() { + res, _ := job4.result() + resCh <- res + }() + + // Mark jobs 1 and 2 as done, freeing up the workers. + done1() + done2() + + assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh}) + + // Mark jobs 3 and 4 as done, freeing up the workers. + done3() + done4() +} + +func TestPool_Process(t *testing.T) { + doubler := newDoubler(runtime.NumCPU()) + defer doubler.Done() + + res := make([]int, 5) + + require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(index, reqVal, resVal int, err error) error { + require.NoError(t, err) + + res[index] = resVal + + return nil + })) + + assert.Equal(t, []int{ + 2, + 4, + 6, + 8, + 10, + }, res) +} + +func TestPool_Process_Error(t *testing.T) { + doubler := newDoublerWithError(runtime.NumCPU()) + defer doubler.Done() + + assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_int, _ int, _ int, err error) error { + return err + })) +} + +func TestPool_Process_Parallel(t *testing.T) { + doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond) + defer doubler.Done() + + var wg sync.WaitGroup + + for i := 0; i < 8; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, _ int, err error) error { + return nil + })) + }() + } + + wg.Wait() +} + +func TestPool_ProcessAll(t *testing.T) { + doubler := newDoubler(runtime.NumCPU()) + defer doubler.Done() + + res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5}) + require.NoError(t, err) + + assert.Equal(t, []int{ + 2, + 4, + 6, + 8, + 10, + }, res) +} + +func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] { + return NewPool(workers, func(ctx context.Context, req int) (int, error) { + if len(delay) > 0 { + time.Sleep(delay[0]) + } + + return 2 * req, nil + }) +} + +func newDoublerWithError(workers int) *Pool[int, int] { + return NewPool(workers, func(ctx context.Context, req int) (int, error) { + if req%2 == 0 { + return 0, errors.New("oops") + } + + return 2 * req, nil + }) +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..c2c64f0 --- /dev/null +++ b/response.go @@ -0,0 +1,84 @@ +package proton + +import ( + "fmt" + "net/http" + "strconv" + "time" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/go-resty/resty/v2" +) + +type Code int + +const ( + SuccessCode Code = 1000 + MultiCode Code = 1001 + InvalidValue Code = 2001 + AppVersionMissingCode Code = 5001 + AppVersionBadCode Code = 5003 + PasswordWrong Code = 8002 + HumanVerificationRequired Code = 9001 + PaidPlanRequired Code = 10004 + AuthRefreshTokenInvalid Code = 10013 +) + +type Error struct { + Code Code + Message string `json:"Error"` +} + +func (err Error) Error() string { + return err.Message +} + +func catchAPIError(_ *resty.Client, res *resty.Response) error { + if !res.IsError() { + return nil + } + + var err error + + if apiErr, ok := res.Error().(*Error); ok { + err = apiErr + } else { + err = fmt.Errorf("%v", res.Status()) + } + + return fmt.Errorf("%v: %w", res.StatusCode(), err) +} + +func updateTime(_ *resty.Client, res *resty.Response) error { + date, err := time.Parse(time.RFC1123, res.Header().Get("Date")) + if err != nil { + return err + } + + crypto.UpdateTime(date.Unix()) + + return nil +} + +func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) { + if res.StatusCode() == http.StatusTooManyRequests { + if after := res.Header().Get("Retry-After"); after != "" { + seconds, err := strconv.Atoi(after) + if err != nil { + return 0, err + } + + return time.Duration(seconds) * time.Second, nil + } + } + + return 0, nil +} + +func catchTooManyRequests(res *resty.Response, _ error) bool { + return res.StatusCode() == http.StatusTooManyRequests +} + +func catchDialError(res *resty.Response, err error) bool { + return res.RawResponse == nil +} diff --git a/salt.go b/salt.go new file mode 100644 index 0000000..72de065 --- /dev/null +++ b/salt.go @@ -0,0 +1,21 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetSalts(ctx context.Context) (Salts, error) { + var res struct { + KeySalts []Salt + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/core/v4/keys/salts") + }); err != nil { + return nil, err + } + + return res.KeySalts, nil +} diff --git a/salt_types.go b/salt_types.go new file mode 100644 index 0000000..8f06861 --- /dev/null +++ b/salt_types.go @@ -0,0 +1,37 @@ +package proton + +import ( + "encoding/base64" + "fmt" + + "github.com/ProtonMail/go-srp" + "github.com/bradenaw/juniper/xslices" +) + +type Salt struct { + ID, KeySalt string +} + +type Salts []Salt + +func (salts Salts) SaltForKey(keyPass []byte, keyID string) ([]byte, error) { + idx := xslices.IndexFunc(salts, func(salt Salt) bool { + return salt.ID == keyID + }) + + if idx < 0 { + return nil, fmt.Errorf("no salt found for key %s", keyID) + } + + keySalt, err := base64.StdEncoding.DecodeString(salts[idx].KeySalt) + if err != nil { + return nil, err + } + + saltedKeyPass, err := srp.MailboxPassword(keyPass, keySalt) + if err != nil { + return nil, nil + } + + return saltedKeyPass[len(saltedKeyPass)-31:], nil +} diff --git a/server/addresses.go b/server/addresses.go new file mode 100644 index 0000000..0e753d4 --- /dev/null +++ b/server/addresses.go @@ -0,0 +1,74 @@ +package server + +import ( + "net/http" + + "github.com/ProtonMail/go-proton-api" + "github.com/bradenaw/juniper/xslices" + "github.com/gin-gonic/gin" + "golang.org/x/exp/slices" +) + +func (s *Server) handleGetAddresses() gin.HandlerFunc { + return func(c *gin.Context) { + addresses, err := s.b.GetAddresses(c.GetString("UserID")) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Addresses": addresses, + }) + } +} + +func (s *Server) handleGetAddress() gin.HandlerFunc { + return func(c *gin.Context) { + addresses, err := s.b.GetAddresses(c.GetString("UserID")) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Address": addresses[xslices.IndexFunc(addresses, func(address proton.Address) bool { + return address.ID == c.Param("addressID") + })], + }) + } +} + +func (s *Server) handlePutAddressesOrder() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.OrderAddressesReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + addresses, err := s.b.GetAddresses(c.GetString("UserID")) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + if len(req.AddressIDs) != len(addresses) { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + for _, address := range addresses { + if !slices.Contains(req.AddressIDs, address.ID) { + c.AbortWithStatus(http.StatusBadRequest) + return + } + } + + if err := s.b.SetAddressOrder(c.GetString("UserID"), req.AddressIDs); err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + } +} diff --git a/server/attachments.go b/server/attachments.go new file mode 100644 index 0000000..4ad5d5e --- /dev/null +++ b/server/attachments.go @@ -0,0 +1,66 @@ +package server + +import ( + "io" + "mime/multipart" + "net/http" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/go-proton-api" + "github.com/gin-gonic/gin" +) + +func (s *Server) handlePostMailAttachments() gin.HandlerFunc { + return func(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + attachment, err := s.b.CreateAttachment( + c.GetString("UserID"), + form.Value["MessageID"][0], + form.Value["Filename"][0], + rfc822.MIMEType(form.Value["MIMEType"][0]), + proton.Disposition(form.Value["Disposition"][0]), + mustReadFileHeader(form.File["KeyPackets"][0]), + mustReadFileHeader(form.File["DataPacket"][0]), + string(mustReadFileHeader(form.File["Signature"][0])), + ) + if err != nil { + _ = c.AbortWithError(http.StatusUnprocessableEntity, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Attachment": attachment, + }) + } +} + +func (s *Server) handleGetMailAttachment() gin.HandlerFunc { + return func(c *gin.Context) { + attData, err := s.b.GetAttachment(c.Param("attachID")) + if err != nil { + _ = c.AbortWithError(http.StatusUnprocessableEntity, err) + return + } + + c.Data(http.StatusOK, "application/octet-stream", attData) + } +} + +func mustReadFileHeader(fh *multipart.FileHeader) []byte { + f, err := fh.Open() + if err != nil { + panic(err) + } + + data, err := io.ReadAll(f) + if err != nil { + panic(err) + } + + return data +} diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 0000000..c09c175 --- /dev/null +++ b/server/auth.go @@ -0,0 +1,124 @@ +package server + +import ( + "encoding/base64" + "net/http" + + "github.com/ProtonMail/go-proton-api" + "github.com/gin-gonic/gin" +) + +func (s *Server) handlePostAuthInfo() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.AuthInfoReq + + if err := c.BindJSON(&req); err != nil { + return + } + + info, err := s.b.NewAuthInfo(req.Username) + if err != nil { + _ = c.AbortWithError(http.StatusUnauthorized, err) + return + } + + c.JSON(http.StatusOK, info) + } +} + +func (s *Server) handlePostAuth() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.AuthReq + + if err := c.BindJSON(&req); err != nil { + return + } + + clientEphemeral, err := base64.StdEncoding.DecodeString(req.ClientEphemeral) + if err != nil { + _ = c.AbortWithError(http.StatusBadRequest, err) + return + } + + clientProof, err := base64.StdEncoding.DecodeString(req.ClientProof) + if err != nil { + _ = c.AbortWithError(http.StatusBadRequest, err) + return + } + + auth, err := s.b.NewAuth(req.Username, clientEphemeral, clientProof, req.SRPSession) + if err != nil { + _ = c.AbortWithError(http.StatusUnauthorized, err) + return + } + + c.JSON(http.StatusOK, auth) + } +} + +func (s *Server) handlePostAuthRefresh() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.AuthRefreshReq + + if err := c.BindJSON(&req); err != nil { + return + } + + auth, err := s.b.NewAuthRef(req.UID, req.RefreshToken) + if err != nil { + _ = c.AbortWithError(http.StatusUnauthorized, err) + return + } + + c.JSON(http.StatusOK, auth) + } +} + +func (s *Server) handleDeleteAuth() gin.HandlerFunc { + return func(c *gin.Context) { + if err := s.b.DeleteSession(c.GetString("UserID"), c.GetString("AuthUID")); err != nil { + _ = c.AbortWithError(http.StatusUnauthorized, err) + return + } + } +} + +func (s *Server) handleGetAuthSessions() gin.HandlerFunc { + return func(c *gin.Context) { + sessions, err := s.b.GetSessions(c.GetString("UserID")) + if err != nil { + _ = c.AbortWithError(http.StatusInternalServerError, err) + return + } + + c.JSON(http.StatusOK, gin.H{"Sessions": sessions}) + } +} + +func (s *Server) handleDeleteAuthSessions() gin.HandlerFunc { + return func(c *gin.Context) { + sessions, err := s.b.GetSessions(c.GetString("UserID")) + if err != nil { + _ = c.AbortWithError(http.StatusInternalServerError, err) + return + } + + for _, session := range sessions { + if session.UID != c.GetString("AuthUID") { + if err := s.b.DeleteSession(c.GetString("UserID"), session.UID); err != nil { + _ = c.AbortWithError(http.StatusInternalServerError, err) + return + } + } + } + } +} + +func (s *Server) handleDeleteAuthSession() gin.HandlerFunc { + return func(c *gin.Context) { + if err := s.b.DeleteSession(c.GetString("UserID"), c.Param("authUID")); err != nil { + _ = c.AbortWithError(http.StatusInternalServerError, err) + return + } + } +} diff --git a/server/backend/account.go b/server/backend/account.go new file mode 100644 index 0000000..45b8844 --- /dev/null +++ b/server/backend/account.go @@ -0,0 +1,87 @@ +package backend + +import ( + "sync" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/xslices" + "github.com/google/uuid" +) + +type account struct { + userID string + username string + addresses map[string]*address + + auth map[string]auth + authLock sync.RWMutex + + keys []key + salt []byte + verifier []byte + + labelIDs []string + messageIDs []string + updateIDs []ID +} + +func newAccount(userID, username string, armKey string, salt, verifier []byte) *account { + return &account{ + userID: userID, + username: username, + addresses: make(map[string]*address), + + auth: make(map[string]auth), + keys: []key{{keyID: uuid.NewString(), key: armKey}}, + salt: salt, + verifier: verifier, + } +} + +func (acc *account) toUser() proton.User { + return proton.User{ + ID: acc.userID, + Name: acc.username, + DisplayName: acc.username, + Email: acc.primary().email, + Keys: xslices.Map(acc.keys, func(key key) proton.Key { + privKey, err := crypto.NewKeyFromArmored(key.key) + if err != nil { + panic(err) + } + + rawKey, err := privKey.Serialize() + if err != nil { + panic(err) + } + + return proton.Key{ + ID: key.keyID, + PrivateKey: rawKey, + Primary: key == acc.keys[0], + Active: true, + } + }), + } +} + +func (acc *account) primary() *address { + for _, addr := range acc.addresses { + if addr.order == 1 { + return addr + } + } + + panic("no primary address") +} + +func (acc *account) getAddr(email string) (*address, bool) { + for _, addr := range acc.addresses { + if addr.email == email { + return addr, true + } + } + + return nil, false +} diff --git a/server/backend/address.go b/server/backend/address.go new file mode 100644 index 0000000..94ea4b5 --- /dev/null +++ b/server/backend/address.go @@ -0,0 +1,49 @@ +package backend + +import ( + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/xslices" +) + +type address struct { + addrID string + email string + order int + keys []key +} + +func (add *address) toAddress() proton.Address { + return proton.Address{ + ID: add.addrID, + Email: add.email, + + Send: true, + Receive: true, + Status: proton.AddressStatusEnabled, + + Order: add.order, + DisplayName: add.email, + + Keys: xslices.Map(add.keys, func(key key) proton.Key { + privKey, err := crypto.NewKeyFromArmored(key.key) + if err != nil { + panic(err) + } + + rawKey, err := privKey.Serialize() + if err != nil { + panic(err) + } + + return proton.Key{ + ID: key.keyID, + PrivateKey: rawKey, + Token: key.tok, + Signature: key.sig, + Primary: key == add.keys[0], + Active: true, + } + }), + } +} diff --git a/server/backend/api.go b/server/backend/api.go new file mode 100644 index 0000000..65ffbaf --- /dev/null +++ b/server/backend/api.go @@ -0,0 +1,772 @@ +package backend + +import ( + "encoding/base64" + "errors" + "fmt" + "net/mail" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/xslices" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +func (b *Backend) GetUser(userID string) (proton.User, error) { + return withAcc(b, userID, func(acc *account) (proton.User, error) { + return acc.toUser(), nil + }) +} + +func (b *Backend) GetKeySalts(userID string) ([]proton.Salt, error) { + return withAcc(b, userID, func(acc *account) ([]proton.Salt, error) { + return xslices.Map(acc.keys, func(key key) proton.Salt { + return proton.Salt{ + ID: key.keyID, + KeySalt: base64.StdEncoding.EncodeToString(acc.salt), + } + }), nil + }) +} + +func (b *Backend) GetMailSettings(userID string) (proton.MailSettings, error) { + return withAcc(b, userID, func(acc *account) (proton.MailSettings, error) { + return proton.MailSettings{ + DisplayName: acc.username, + DraftMIMEType: rfc822.TextHTML, + }, nil + }) +} + +func (b *Backend) GetAddressID(email string) (string, error) { + return withAccEmail(b, email, func(acc *account) (string, error) { + addr, ok := acc.getAddr(email) + if !ok { + return "", fmt.Errorf("no such address: %s", email) + } + + return addr.addrID, nil + }) +} + +func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) { + return withAcc(b, userID, func(acc *account) ([]proton.Address, error) { + return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address { + return add.toAddress() + }), nil + }) +} + +func (b *Backend) SetAddressOrder(userID string, addrIDs []string) error { + return b.withAcc(userID, func(acc *account) error { + for i, addrID := range addrIDs { + if add, ok := acc.addresses[addrID]; ok { + add.order = i + 1 + } else { + return fmt.Errorf("no such address: %s", addrID) + } + } + + return nil + }) +} + +func (b *Backend) HasLabel(userID, labelName string) (string, bool, error) { + labels, err := b.GetLabels(userID) + if err != nil { + return "", false, err + } + + for _, label := range labels { + if label.Name == labelName { + return label.ID, true, nil + } + } + + return "", false, nil +} + +func (b *Backend) GetLabel(userID, labelID string) (proton.Label, error) { + labels, err := b.GetLabels(userID) + if err != nil { + return proton.Label{}, err + } + + for _, label := range labels { + if label.ID == labelID { + return label, nil + } + } + + return proton.Label{}, fmt.Errorf("no such label: %s", labelID) +} + +func (b *Backend) GetLabels(userID string, types ...proton.LabelType) ([]proton.Label, error) { + return withAcc(b, userID, func(acc *account) ([]proton.Label, error) { + return withLabels(b, func(labels map[string]*label) ([]proton.Label, error) { + res := xslices.Map(acc.labelIDs, func(labelID string) proton.Label { + return labels[labelID].toLabel(labels) + }) + + for labelName, labelID := range map[string]string{ + "Inbox": proton.InboxLabel, + "AllDrafts": proton.AllDraftsLabel, + "AllSent": proton.AllSentLabel, + "Trash": proton.TrashLabel, + "Spam": proton.SpamLabel, + "All Mail": proton.AllMailLabel, + "Archive": proton.ArchiveLabel, + "Sent": proton.SentLabel, + "Drafts": proton.DraftsLabel, + "Outbox": proton.OutboxLabel, + "Starred": proton.StarredLabel, + } { + res = append(res, proton.Label{ + ID: labelID, + Name: labelName, + Path: []string{labelName}, + Type: proton.LabelTypeSystem, + }) + } + + if len(types) > 0 { + res = xslices.Filter(res, func(label proton.Label) bool { + return slices.Contains(types, label.Type) + }) + } + + return res, nil + }) + }) +} + +func (b *Backend) CreateLabel(userID, labelName, parentID string, labelType proton.LabelType) (proton.Label, error) { + return withAcc(b, userID, func(acc *account) (proton.Label, error) { + return withLabels(b, func(labels map[string]*label) (proton.Label, error) { + if parentID != "" { + if labelType != proton.LabelTypeFolder { + return proton.Label{}, fmt.Errorf("parentID can only be set for folders") + } + + if _, ok := labels[parentID]; !ok { + return proton.Label{}, fmt.Errorf("no such parent label: %s", parentID) + } + } + + label := newLabel(labelName, parentID, labelType) + + labels[label.labelID] = label + + updateID, err := b.newUpdate(&labelCreated{labelID: label.labelID}) + if err != nil { + return proton.Label{}, err + } + + acc.labelIDs = append(acc.labelIDs, label.labelID) + acc.updateIDs = append(acc.updateIDs, updateID) + + return label.toLabel(labels), nil + }) + }) +} + +func (b *Backend) UpdateLabel(userID, labelID, name, parentID string) (proton.Label, error) { + return withAcc(b, userID, func(acc *account) (proton.Label, error) { + return withLabels(b, func(labels map[string]*label) (proton.Label, error) { + if parentID != "" { + if labels[labelID].labelType != proton.LabelTypeFolder { + return proton.Label{}, fmt.Errorf("parentID can only be set for folders") + } + + if _, ok := labels[parentID]; !ok { + return proton.Label{}, fmt.Errorf("no such parent label: %s", parentID) + } + } + + labels[labelID].name = name + labels[labelID].parentID = parentID + + updateID, err := b.newUpdate(&labelUpdated{labelID: labelID}) + if err != nil { + return proton.Label{}, err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return labels[labelID].toLabel(labels), nil + }) + }) +} + +func (b *Backend) DeleteLabel(userID, labelID string) error { + return b.withAcc(userID, func(acc *account) error { + return b.withLabels(func(labels map[string]*label) error { + if _, ok := labels[labelID]; !ok { + return errors.New("label not found") + } + + for _, labelID := range getLabelIDsToDelete(labelID, labels) { + delete(labels, labelID) + + updateID, err := b.newUpdate(&labelDeleted{labelID: labelID}) + if err != nil { + return err + } + + acc.labelIDs = xslices.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID }) + acc.updateIDs = append(acc.updateIDs, updateID) + } + + return nil + }) + }) +} + +func (b *Backend) CountMessages(userID string) (int, error) { + return withAcc(b, userID, func(acc *account) (int, error) { + return len(acc.messageIDs), nil + }) +} + +func (b *Backend) GetMessageIDs(userID string, afterID string, limit int) ([]string, error) { + return withAcc(b, userID, func(acc *account) ([]string, error) { + if len(acc.messageIDs) == 0 { + return nil, nil + } + + var lo, hi int + + if afterID == "" { + lo = 0 + } else { + lo = slices.Index(acc.messageIDs, afterID) + 1 + } + + if limit == 0 { + hi = len(acc.messageIDs) + } else { + hi = lo + limit + + if hi > len(acc.messageIDs) { + hi = len(acc.messageIDs) + } + } + + return acc.messageIDs[lo:hi], nil + }) +} + +func (b *Backend) GetMessages(userID string, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error) { + return withAcc(b, userID, func(acc *account) ([]proton.MessageMetadata, error) { + return withMessages(b, func(messages map[string]*message) ([]proton.MessageMetadata, error) { + if len(acc.messageIDs) == 0 { + return nil, nil + } + + metadata := xslices.Map(xslices.Chunk(acc.messageIDs, pageSize)[page], func(messageID string) proton.MessageMetadata { + return messages[messageID].toMetadata() + }) + + if len(filter.ID) > 0 { + metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool { + return slices.Contains(filter.ID, metadata.ID) + }) + } + + if len(filter.AddressID) != 0 { + metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool { + return filter.AddressID == metadata.AddressID + }) + } + + if len(filter.ExternalID) != 0 { + metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool { + return filter.ExternalID != metadata.ExternalID + }) + } + + if len(filter.LabelID) != 0 { + metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool { + return slices.Contains(metadata.LabelIDs, filter.LabelID) + }) + } + + return metadata, nil + }) + }) +} + +func (b *Backend) GetMessage(userID, messageID string) (proton.Message, error) { + return withAcc(b, userID, func(acc *account) (proton.Message, error) { + return withMessages(b, func(messages map[string]*message) (proton.Message, error) { + return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) { + message, ok := messages[messageID] + if !ok { + return proton.Message{}, errors.New("no such message") + } + + return message.toMessage(atts), nil + }) + }) + }) +} + +func (b *Backend) SetMessagesRead(userID string, read bool, messageIDs ...string) error { + return b.withAcc(userID, func(acc *account) error { + return b.withMessages(func(messages map[string]*message) error { + for _, messageID := range messageIDs { + messages[messageID].unread = !read + + updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + } + + return nil + }) + }) +} + +func (b *Backend) LabelMessages(userID, labelID string, messageIDs ...string) error { + if labelID == proton.AllMailLabel || labelID == proton.AllDraftsLabel || labelID == proton.AllSentLabel { + return fmt.Errorf("not allowed") + } + + return b.withAcc(userID, func(acc *account) error { + return b.withMessages(func(messages map[string]*message) error { + return b.withLabels(func(labels map[string]*label) error { + for _, messageID := range messageIDs { + message, ok := messages[messageID] + if !ok { + continue + } + + message.addLabel(labelID, labels) + + updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + } + + return nil + }) + }) + }) +} + +func (b *Backend) UnlabelMessages(userID, labelID string, messageIDs ...string) error { + if labelID == proton.AllMailLabel || labelID == proton.AllDraftsLabel || labelID == proton.AllSentLabel { + return fmt.Errorf("not allowed") + } + + return b.withAcc(userID, func(acc *account) error { + return b.withMessages(func(messages map[string]*message) error { + return b.withLabels(func(labels map[string]*label) error { + for _, messageID := range messageIDs { + messages[messageID].remLabel(labelID, labels) + + updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + } + + return nil + }) + }) + }) +} + +func (b *Backend) DeleteMessage(userID, messageID string) error { + return b.withAcc(userID, func(acc *account) error { + return b.withMessages(func(messages map[string]*message) error { + message, ok := messages[messageID] + if !ok { + return errors.New("no such message") + } + + for _, attID := range message.attIDs { + if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool { + return att.attDataID == b.attachments[attID].attDataID + }) == 1 { + delete(b.attData, b.attachments[attID].attDataID) + } + + delete(b.attachments, attID) + } + + delete(b.messages, messageID) + + updateID, err := b.newUpdate(&messageDeleted{messageID: messageID}) + if err != nil { + return err + } + + acc.messageIDs = xslices.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID }) + acc.updateIDs = append(acc.updateIDs, updateID) + + return nil + }) + }) +} + +func (b *Backend) CreateDraft( + userID, addrID string, + subject string, + sender *mail.Address, + toList, ccList, bccList []*mail.Address, + armBody string, + mimeType rfc822.MIMEType, + externalID string, +) (proton.Message, error) { + return withAcc(b, userID, func(acc *account) (proton.Message, error) { + return withMessages(b, func(messages map[string]*message) (proton.Message, error) { + msg := newMessage(addrID, subject, sender, toList, ccList, bccList, armBody, mimeType, externalID) + + messages[msg.messageID] = msg + + updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID}) + if err != nil { + return proton.Message{}, err + } + + acc.messageIDs = append(acc.messageIDs, msg.messageID) + acc.updateIDs = append(acc.updateIDs, updateID) + + return msg.toMessage(nil), nil + }) + }) +} + +func (b *Backend) SendMessage(userID, messageID string, packages []*proton.MessagePackage) (proton.Message, error) { + return withAcc(b, userID, func(acc *account) (proton.Message, error) { + return withMessages(b, func(messages map[string]*message) (proton.Message, error) { + return withLabels(b, func(labels map[string]*label) (proton.Message, error) { + return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) { + msg := messages[messageID] + msg.flags |= proton.MessageFlagSent + msg.addLabel(proton.SentLabel, labels) + + updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) + if err != nil { + return proton.Message{}, err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + for _, pkg := range packages { + bodyData, err := base64.StdEncoding.DecodeString(pkg.Body) + if err != nil { + return proton.Message{}, err + } + + for email, recipient := range pkg.Addresses { + if recipient.Type != proton.InternalScheme { + continue + } + + if err := b.withAccEmail(email, func(acc *account) error { + bodyKey, err := base64.StdEncoding.DecodeString(recipient.BodyKeyPacket) + if err != nil { + return err + } + + armBody, err := crypto.NewPGPSplitMessage(bodyKey, bodyData).GetPGPMessage().GetArmored() + if err != nil { + return err + } + + addrID, err := b.GetAddressID(email) + if err != nil { + return err + } + + newMsg := newMessage( + addrID, + msg.subject, + msg.sender, + msg.toList, + msg.ccList, + nil, // BCC is not sent to the recipient + armBody, + msg.mimeType, + msg.externalID, + ) + newMsg.flags |= proton.MessageFlagReceived + newMsg.addLabel(proton.InboxLabel, labels) + newMsg.unread = true + messages[newMsg.messageID] = newMsg + + for _, attID := range msg.attIDs { + attKey, err := base64.StdEncoding.DecodeString(recipient.AttachmentKeyPackets[attID]) + if err != nil { + return err + } + + att := newAttachment( + atts[attID].filename, + atts[attID].mimeType, + atts[attID].disposition, + attKey, + atts[attID].attDataID, + atts[attID].armSig, + ) + atts[att.attachID] = att + messages[newMsg.messageID].attIDs = append(messages[newMsg.messageID].attIDs, att.attachID) + } + + updateID, err := b.newUpdate(&messageCreated{messageID: newMsg.messageID}) + if err != nil { + return err + } + + acc.messageIDs = append(acc.messageIDs, newMsg.messageID) + acc.updateIDs = append(acc.updateIDs, updateID) + + return nil + }); err != nil { + return proton.Message{}, err + } + } + } + + return msg.toMessage(atts), nil + }) + }) + }) + }) +} + +func (b *Backend) CreateAttachment( + userID string, + messageID string, + filename string, + mimeType rfc822.MIMEType, + disposition proton.Disposition, + keyPackets, dataPacket []byte, + armSig string, +) (proton.Attachment, error) { + return withAcc(b, userID, func(acc *account) (proton.Attachment, error) { + return withMessages(b, func(messages map[string]*message) (proton.Attachment, error) { + return withAtts(b, func(atts map[string]*attachment) (proton.Attachment, error) { + att := newAttachment( + filename, + mimeType, + disposition, + keyPackets, + b.createAttData(dataPacket), + armSig, + ) + + atts[att.attachID] = att + + messages[messageID].attIDs = append(messages[messageID].attIDs, att.attachID) + + updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) + if err != nil { + return proton.Attachment{}, err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return att.toAttachment(), nil + }) + }) + }) +} + +func (b *Backend) GetAttachment(attachID string) ([]byte, error) { + return withAtts(b, func(atts map[string]*attachment) ([]byte, error) { + att, ok := atts[attachID] + if !ok { + return nil, fmt.Errorf("no such attachment: %s", attachID) + } + + return b.attData[att.attDataID], nil + }) +} + +func (b *Backend) GetLatestEventID(userID string) (string, error) { + return withAcc(b, userID, func(acc *account) (string, error) { + return acc.updateIDs[len(acc.updateIDs)-1].String(), nil + }) +} + +func (b *Backend) GetEvent(userID, rawEventID string) (proton.Event, error) { + var eventID ID + + if err := eventID.FromString(rawEventID); err != nil { + return proton.Event{}, fmt.Errorf("invalid event ID: %s", rawEventID) + } + + return withAcc(b, userID, func(acc *account) (proton.Event, error) { + return withMessages(b, func(messages map[string]*message) (proton.Event, error) { + return withLabels(b, func(labels map[string]*label) (proton.Event, error) { + updates, err := withUpdates(b, func(updates map[ID]update) ([]update, error) { + return merge(xslices.Map(acc.updateIDs[xslices.Index(acc.updateIDs, eventID)+1:], func(updateID ID) update { + return updates[updateID] + })), nil + }) + if err != nil { + return proton.Event{}, fmt.Errorf("failed to merge updates: %w", err) + } + + return buildEvent(updates, acc.addresses, messages, labels, acc.updateIDs[len(acc.updateIDs)-1].String()), nil + }) + }) + }) +} + +func (b *Backend) GetPublicKeys(email string) ([]proton.PublicKey, error) { + return withAccEmail(b, email, func(acc *account) ([]proton.PublicKey, error) { + var keys []proton.PublicKey + + for _, addr := range acc.addresses { + if addr.email == email { + for _, key := range addr.keys { + pubKey, err := key.getPubKey() + if err != nil { + return nil, err + } + + armKey, err := pubKey.GetArmoredPublicKey() + if err != nil { + return nil, err + } + + keys = append(keys, proton.PublicKey{ + Flags: proton.KeyStateTrusted | proton.KeyStateActive, + PublicKey: armKey, + }) + } + } + } + + return keys, nil + }) +} + +func getLabelIDsToDelete(labelID string, labels map[string]*label) []string { + labelIDs := []string{labelID} + + for _, label := range labels { + if label.parentID == labelID { + labelIDs = append(labelIDs, getLabelIDsToDelete(label.labelID, labels)...) + } + } + + return labelIDs +} + +func buildEvent( + updates []update, + addresses map[string]*address, + messages map[string]*message, + labels map[string]*label, + eventID string, +) proton.Event { + event := proton.Event{EventID: eventID} + + for _, update := range updates { + switch update := update.(type) { + case *userRefreshed: + event.Refresh = update.refresh + + case *messageCreated: + event.Messages = append(event.Messages, proton.MessageEvent{ + EventItem: proton.EventItem{ + ID: update.messageID, + Action: proton.EventCreate, + }, + + Message: messages[update.messageID].toMetadata(), + }) + + case *messageUpdated: + event.Messages = append(event.Messages, proton.MessageEvent{ + EventItem: proton.EventItem{ + ID: update.messageID, + Action: proton.EventUpdate, + }, + + Message: messages[update.messageID].toMetadata(), + }) + + case *messageDeleted: + event.Messages = append(event.Messages, proton.MessageEvent{ + EventItem: proton.EventItem{ + ID: update.messageID, + Action: proton.EventDelete, + }, + }) + + case *labelCreated: + event.Labels = append(event.Labels, proton.LabelEvent{ + EventItem: proton.EventItem{ + ID: update.labelID, + Action: proton.EventCreate, + }, + + Label: labels[update.labelID].toLabel(labels), + }) + + case *labelUpdated: + event.Labels = append(event.Labels, proton.LabelEvent{ + EventItem: proton.EventItem{ + ID: update.labelID, + Action: proton.EventUpdate, + }, + + Label: labels[update.labelID].toLabel(labels), + }) + + case *labelDeleted: + event.Labels = append(event.Labels, proton.LabelEvent{ + EventItem: proton.EventItem{ + ID: update.labelID, + Action: proton.EventDelete, + }, + }) + + case *addressCreated: + event.Addresses = append(event.Addresses, proton.AddressEvent{ + EventItem: proton.EventItem{ + ID: update.addressID, + Action: proton.EventCreate, + }, + + Address: addresses[update.addressID].toAddress(), + }) + + case *addressUpdated: + event.Addresses = append(event.Addresses, proton.AddressEvent{ + EventItem: proton.EventItem{ + ID: update.addressID, + Action: proton.EventCreate, + }, + + Address: addresses[update.addressID].toAddress(), + }) + + case *addressDeleted: + event.Addresses = append(event.Addresses, proton.AddressEvent{ + EventItem: proton.EventItem{ + ID: update.addressID, + Action: proton.EventDelete, + }, + }) + } + } + + return event +} diff --git a/server/backend/api_auth.go b/server/backend/api_auth.go new file mode 100644 index 0000000..6c01c58 --- /dev/null +++ b/server/backend/api_auth.go @@ -0,0 +1,127 @@ +package backend + +import ( + "encoding/base64" + "fmt" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-srp" + "github.com/google/uuid" +) + +func (b *Backend) NewAuthInfo(username string) (proton.AuthInfo, error) { + return withAccName(b, username, func(acc *account) (proton.AuthInfo, error) { + server, err := srp.NewServerFromSigned(modulus, acc.verifier, 2048) + if err != nil { + return proton.AuthInfo{}, nil + } + + challenge, err := server.GenerateChallenge() + if err != nil { + return proton.AuthInfo{}, nil + } + + session := uuid.NewString() + + b.srpLock.Lock() + defer b.srpLock.Unlock() + + b.srp[session] = server + + return proton.AuthInfo{ + Version: 4, + Modulus: modulus, + ServerEphemeral: base64.StdEncoding.EncodeToString(challenge), + Salt: base64.StdEncoding.EncodeToString(acc.salt), + SRPSession: session, + TwoFA: proton.TwoFAInfo{Enabled: proton.TwoFADisabled}, + }, nil + }) +} + +func (b *Backend) NewAuth(username string, ephemeral, proof []byte, session string) (proton.Auth, error) { + return withAccName(b, username, func(acc *account) (proton.Auth, error) { + b.srpLock.Lock() + defer b.srpLock.Unlock() + + server, ok := b.srp[session] + if !ok { + return proton.Auth{}, fmt.Errorf("invalid session") + } + + delete(b.srp, session) + + serverProof, err := server.VerifyProofs(ephemeral, proof) + if !ok { + return proton.Auth{}, fmt.Errorf("invalid proof: %w", err) + } + + authUID, auth := uuid.NewString(), newAuth(b.authLife) + + acc.authLock.Lock() + defer acc.authLock.Unlock() + + acc.auth[authUID] = auth + + return auth.toAuth(acc.userID, authUID, serverProof), nil + }) +} + +func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + acc.authLock.Lock() + defer acc.authLock.Unlock() + + auth, ok := acc.auth[authUID] + if !ok { + continue + } + + if auth.ref != authRef { + return proton.Auth{}, fmt.Errorf("invalid auth ref") + } + + newAuth := newAuth(b.authLife) + + acc.auth[authUID] = newAuth + + return newAuth.toAuth(acc.userID, authUID, nil), nil + } + + return proton.Auth{}, fmt.Errorf("invalid auth") +} + +func (b *Backend) VerifyAuth(authUID, authAcc string) (string, error) { + return withAccAuth(b, authUID, authAcc, func(acc *account) (string, error) { + return acc.userID, nil + }) +} + +func (b *Backend) GetSessions(userID string) ([]proton.AuthSession, error) { + return withAcc(b, userID, func(acc *account) ([]proton.AuthSession, error) { + acc.authLock.RLock() + defer acc.authLock.RUnlock() + + var sessions []proton.AuthSession + + for authUID, auth := range acc.auth { + sessions = append(sessions, auth.toAuthSession(authUID)) + } + + return sessions, nil + }) +} + +func (b *Backend) DeleteSession(userID, authUID string) error { + return b.withAcc(userID, func(acc *account) error { + acc.authLock.Lock() + defer acc.authLock.Unlock() + + delete(acc.auth, authUID) + + return nil + }) +} diff --git a/server/backend/attachment.go b/server/backend/attachment.go new file mode 100644 index 0000000..6c81dcf --- /dev/null +++ b/server/backend/attachment.go @@ -0,0 +1,66 @@ +package backend + +import ( + "encoding/base64" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/go-proton-api" + "github.com/google/uuid" +) + +func (b *Backend) createAttData(dataPacket []byte) string { + attDataID := uuid.NewString() + + b.attDataLock.Lock() + defer b.attDataLock.Unlock() + + b.attData[attDataID] = dataPacket + + return attDataID +} + +type attachment struct { + attachID string + attDataID string + + filename string + mimeType rfc822.MIMEType + disposition proton.Disposition + + keyPackets []byte + armSig string +} + +func newAttachment( + filename string, + mimeType rfc822.MIMEType, + disposition proton.Disposition, + keyPackets []byte, + dataPacketID string, + armSig string, +) *attachment { + return &attachment{ + attachID: uuid.NewString(), + attDataID: dataPacketID, + + filename: filename, + mimeType: mimeType, + disposition: disposition, + + keyPackets: keyPackets, + armSig: armSig, + } +} + +func (att *attachment) toAttachment() proton.Attachment { + return proton.Attachment{ + ID: att.attachID, + + Name: att.filename, + MIMEType: att.mimeType, + Disposition: att.disposition, + + KeyPackets: base64.StdEncoding.EncodeToString(att.keyPackets), + Signature: att.armSig, + } +} diff --git a/server/backend/backend.go b/server/backend/backend.go new file mode 100644 index 0000000..2141c64 --- /dev/null +++ b/server/backend/backend.go @@ -0,0 +1,544 @@ +package backend + +import ( + "fmt" + "net/mail" + "sync" + "time" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-srp" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/xslices" + "github.com/google/uuid" + "golang.org/x/exp/maps" +) + +type Backend struct { + accounts map[string]*account + accLock sync.RWMutex + + attachments map[string]*attachment + attLock sync.Mutex + + attData map[string][]byte + attDataLock sync.Mutex + + messages map[string]*message + msgLock sync.Mutex + + labels map[string]*label + lblLock sync.Mutex + + updates map[ID]update + updatesLock sync.RWMutex + + srp map[string]*srp.Server + srpLock sync.Mutex + + authLife time.Duration +} + +func New(authLife time.Duration) *Backend { + return &Backend{ + accounts: make(map[string]*account), + attachments: make(map[string]*attachment), + attData: make(map[string][]byte), + messages: make(map[string]*message), + labels: make(map[string]*label), + updates: make(map[ID]update), + srp: make(map[string]*srp.Server), + + authLife: authLife, + } +} + +func (b *Backend) SetAuthLife(authLife time.Duration) { + b.authLife = authLife +} + +func (b *Backend) CreateUser(username string, password []byte) (string, error) { + b.accLock.Lock() + defer b.accLock.Unlock() + + salt, err := crypto.RandomToken(16) + if err != nil { + return "", err + } + + passphrase, err := hashPassword(password, salt) + if err != nil { + return "", err + } + + srpAuth, err := srp.NewAuthForVerifier(password, modulus, salt) + if err != nil { + return "", err + } + + verifier, err := srpAuth.GenerateVerifier(2048) + if err != nil { + return "", err + } + + armKey, err := GenerateKey(username, username, passphrase, "rsa", 2048) + if err != nil { + return "", err + } + + userID := uuid.NewString() + + b.accounts[userID] = newAccount(userID, username, armKey, salt, verifier) + + return userID, nil +} + +func (b *Backend) RemoveUser(userID string) error { + b.accLock.Lock() + defer b.accLock.Unlock() + + user, ok := b.accounts[userID] + if !ok { + return fmt.Errorf("user %s does not exist", userID) + } + + for _, labelID := range user.labelIDs { + delete(b.labels, labelID) + } + + for _, messageID := range user.messageIDs { + for _, attID := range b.messages[messageID].attIDs { + if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool { + return att.attDataID == b.attachments[attID].attDataID + }) == 1 { + delete(b.attData, b.attachments[attID].attDataID) + } + + delete(b.attachments, attID) + } + + delete(b.messages, messageID) + } + + delete(b.accounts, userID) + + return nil +} + +func (b *Backend) RefreshUser(userID string, refresh proton.RefreshFlag) error { + return b.withAcc(userID, func(acc *account) error { + updateID, err := b.newUpdate(&userRefreshed{refresh: refresh}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return nil + }) +} + +func (b *Backend) CreateUserKey(userID string, password []byte) error { + b.accLock.Lock() + defer b.accLock.Unlock() + + user, ok := b.accounts[userID] + if !ok { + return fmt.Errorf("user %s does not exist", userID) + } + + salt, err := crypto.RandomToken(16) + if err != nil { + return err + } + + passphrase, err := hashPassword(password, salt) + if err != nil { + return err + } + + armKey, err := GenerateKey(user.username, user.username, passphrase, "rsa", 2048) + if err != nil { + return err + } + + user.keys = append(user.keys, key{keyID: uuid.NewString(), key: armKey}) + + return nil +} + +func (b *Backend) RemoveUserKey(userID, keyID string) error { + b.accLock.Lock() + defer b.accLock.Unlock() + + user, ok := b.accounts[userID] + if !ok { + return fmt.Errorf("user %s does not exist", userID) + } + + idx := xslices.IndexFunc(user.keys, func(key key) bool { + return key.keyID == keyID + }) + + if idx == -1 { + return fmt.Errorf("key %s does not exist", keyID) + } + + user.keys = append(user.keys[:idx], user.keys[idx+1:]...) + + return nil +} + +func (b *Backend) CreateAddress(userID, email string, password []byte) (string, error) { + return withAcc(b, userID, func(acc *account) (string, error) { + token, err := crypto.RandomToken(32) + if err != nil { + return "", err + } + + armKey, err := GenerateKey(acc.username, email, token, "rsa", 2048) + if err != nil { + return "", err + } + + passphrase, err := hashPassword([]byte(password), acc.salt) + if err != nil { + return "", err + } + + userKR, err := acc.keys[0].unlock(passphrase) + if err != nil { + return "", err + } + + encToken, sigToken, err := encryptWithSignature(userKR, token) + if err != nil { + return "", err + } + + addressID := uuid.NewString() + + acc.addresses[addressID] = &address{ + addrID: addressID, + email: email, + order: len(acc.addresses) + 1, + keys: []key{{ + keyID: uuid.NewString(), + key: armKey, + tok: encToken, + sig: sigToken, + }}, + } + + updateID, err := b.newUpdate(&addressCreated{addressID: addressID}) + if err != nil { + return "", err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return addressID, nil + }) +} + +func (b *Backend) CreateAddressKey(userID, addrID string, password []byte) error { + return b.withAcc(userID, func(acc *account) error { + token, err := crypto.RandomToken(32) + if err != nil { + return err + } + + armKey, err := GenerateKey(acc.username, acc.addresses[addrID].email, token, "rsa", 2048) + if err != nil { + return err + } + + passphrase, err := hashPassword([]byte(password), acc.salt) + if err != nil { + return err + } + + userKR, err := acc.keys[0].unlock(passphrase) + if err != nil { + return err + } + + encToken, sigToken, err := encryptWithSignature(userKR, token) + if err != nil { + return err + } + + acc.addresses[addrID].keys = append(acc.addresses[addrID].keys, key{ + keyID: uuid.NewString(), + key: armKey, + tok: encToken, + sig: sigToken, + }) + + updateID, err := b.newUpdate(&addressUpdated{addressID: addrID}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return nil + }) +} + +func (b *Backend) RemoveAddress(userID, addrID string) error { + return b.withAcc(userID, func(acc *account) error { + if _, ok := acc.addresses[addrID]; !ok { + return fmt.Errorf("address %s not found", addrID) + } + + delete(acc.addresses, addrID) + + updateID, err := b.newUpdate(&addressDeleted{addressID: addrID}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return nil + }) +} + +func (b *Backend) RemoveAddressKey(userID, addrID, keyID string) error { + return b.withAcc(userID, func(acc *account) error { + idx := xslices.IndexFunc(acc.addresses[addrID].keys, func(key key) bool { + return key.keyID == keyID + }) + + if idx < 0 { + return fmt.Errorf("key %s not found", keyID) + } + + acc.addresses[addrID].keys = append(acc.addresses[addrID].keys[:idx], acc.addresses[addrID].keys[idx+1:]...) + + updateID, err := b.newUpdate(&addressUpdated{addressID: addrID}) + if err != nil { + return err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return nil + }) +} + +func (b *Backend) CreateMessage( + userID, addrID string, + subject string, + sender *mail.Address, + toList, ccList, bccList []*mail.Address, + armBody string, + mimeType rfc822.MIMEType, + flags proton.MessageFlag, + unread, starred bool, +) (string, error) { + return withAcc(b, userID, func(acc *account) (string, error) { + return withMessages(b, func(messages map[string]*message) (string, error) { + msg := newMessage(addrID, subject, sender, toList, ccList, bccList, armBody, mimeType, "") + + msg.flags |= flags + msg.unread = unread + msg.starred = starred + + messages[msg.messageID] = msg + + updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID}) + if err != nil { + return "", err + } + + acc.messageIDs = append(acc.messageIDs, msg.messageID) + acc.updateIDs = append(acc.updateIDs, updateID) + + return msg.messageID, nil + }) + }) +} + +func (b *Backend) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) (string, error) { + return withAcc(b, userID, func(acc *account) (string, error) { + return withMessages(b, func(messages map[string]*message) (string, error) { + if _, ok := messages[draftID]; !ok { + return "", fmt.Errorf("message %q not found", draftID) + } + + messages[draftID].applyChanges(changes) + + updateID, err := b.newUpdate(&messageUpdated{messageID: draftID}) + if err != nil { + return "", err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return draftID, nil + }) + }) +} + +func (b *Backend) Encrypt(userID, addrID, decBody string) (string, error) { + return withAcc(b, userID, func(acc *account) (string, error) { + pubKey, err := acc.addresses[addrID].keys[0].getPubKey() + if err != nil { + return "", err + } + + kr, err := crypto.NewKeyRing(pubKey) + if err != nil { + return "", err + } + + enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(decBody), nil) + if err != nil { + return "", err + } + + return enc.GetArmored() + }) +} + +func (b *Backend) withAcc(userID string, fn func(acc *account) error) error { + b.accLock.RLock() + defer b.accLock.RUnlock() + + acc, ok := b.accounts[userID] + if !ok { + return fmt.Errorf("account %s not found", userID) + } + + return fn(acc) +} + +func (b *Backend) withAccEmail(email string, fn func(acc *account) error) error { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + for _, addr := range acc.addresses { + if addr.email == email { + return fn(acc) + } + } + } + + return fmt.Errorf("account %s not found", email) +} + +func withAcc[T any](b *Backend, userID string, fn func(acc *account) (T, error)) (T, error) { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + if acc.userID == userID { + return fn(acc) + } + } + + return *new(T), fmt.Errorf("account not found") +} + +func withAccName[T any](b *Backend, username string, fn func(acc *account) (T, error)) (T, error) { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + if acc.username == username { + return fn(acc) + } + } + + return *new(T), fmt.Errorf("account not found") +} + +func withAccEmail[T any](b *Backend, email string, fn func(acc *account) (T, error)) (T, error) { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + if _, ok := acc.getAddr(email); ok { + return fn(acc) + } + } + + return *new(T), fmt.Errorf("account not found") +} + +func withAccAuth[T any](b *Backend, authUID, authAcc string, fn func(acc *account) (T, error)) (T, error) { + b.accLock.RLock() + defer b.accLock.RUnlock() + + for _, acc := range b.accounts { + acc.authLock.RLock() + defer acc.authLock.RUnlock() + + auth, ok := acc.auth[authUID] + if !ok { + continue + } + + if auth.acc == authAcc { + return fn(acc) + } + } + + return *new(T), fmt.Errorf("account not found") +} + +func (b *Backend) withMessages(fn func(map[string]*message) error) error { + b.msgLock.Lock() + defer b.msgLock.Unlock() + + return fn(b.messages) +} + +func withMessages[T any](b *Backend, fn func(map[string]*message) (T, error)) (T, error) { + b.msgLock.Lock() + defer b.msgLock.Unlock() + + return fn(b.messages) +} + +func withAtts[T any](b *Backend, fn func(map[string]*attachment) (T, error)) (T, error) { + b.attLock.Lock() + defer b.attLock.Unlock() + + return fn(b.attachments) +} + +func (b *Backend) withLabels(fn func(map[string]*label) error) error { + b.lblLock.Lock() + defer b.lblLock.Unlock() + + return fn(b.labels) +} + +func withLabels[T any](b *Backend, fn func(map[string]*label) (T, error)) (T, error) { + b.lblLock.Lock() + defer b.lblLock.Unlock() + + return fn(b.labels) +} + +func (b *Backend) newUpdate(event update) (ID, error) { + return withUpdates(b, func(updates map[ID]update) (ID, error) { + updateID := ID(len(updates)) + + updates[updateID] = event + + return updateID, nil + }) +} + +func withUpdates[T any](b *Backend, fn func(map[ID]update) (T, error)) (T, error) { + b.updatesLock.Lock() + defer b.updatesLock.Unlock() + + return fn(b.updates) +} diff --git a/server/backend/crypto.go b/server/backend/crypto.go new file mode 100644 index 0000000..48b9716 --- /dev/null +++ b/server/backend/crypto.go @@ -0,0 +1,42 @@ +package backend + +import ( + "github.com/ProtonMail/go-srp" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/ProtonMail/gopenpgp/v2/helper" +) + +var GenerateKey = helper.GenerateKey + +func hashPassword(password, salt []byte) ([]byte, error) { + passphrase, err := srp.MailboxPassword(password, salt) + if err != nil { + return nil, err + } + + return passphrase[len(passphrase)-31:], nil +} + +func encryptWithSignature(kr *crypto.KeyRing, b []byte) (string, string, error) { + enc, err := kr.Encrypt(crypto.NewPlainMessage(b), nil) + if err != nil { + return "", "", err + } + + encArm, err := enc.GetArmored() + if err != nil { + return "", "", err + } + + sig, err := kr.SignDetached(crypto.NewPlainMessage(b)) + if err != nil { + return "", "", err + } + + sigArm, err := sig.GetArmored() + if err != nil { + return "", "", err + } + + return encArm, sigArm, nil +} diff --git a/server/backend/label.go b/server/backend/label.go new file mode 100644 index 0000000..bdb42bf --- /dev/null +++ b/server/backend/label.go @@ -0,0 +1,39 @@ +package backend + +import ( + "github.com/ProtonMail/go-proton-api" + "github.com/google/uuid" +) + +type label struct { + labelID string + parentID string + name string + labelType proton.LabelType + messageIDs map[string]struct{} +} + +func newLabel(labelName, parentID string, labelType proton.LabelType) *label { + return &label{ + labelID: uuid.NewString(), + parentID: parentID, + name: labelName, + labelType: labelType, + messageIDs: make(map[string]struct{}), + } +} + +func (label *label) toLabel(labels map[string]*label) proton.Label { + var path []string + + for labelID := label.labelID; labelID != ""; labelID = labels[labelID].parentID { + path = append([]string{labels[labelID].name}, path...) + } + + return proton.Label{ + ID: label.labelID, + Name: label.name, + Path: path, + Type: label.labelType, + } +} diff --git a/server/backend/message.go b/server/backend/message.go new file mode 100644 index 0000000..de63abb --- /dev/null +++ b/server/backend/message.go @@ -0,0 +1,300 @@ +package backend + +import ( + "net/mail" + "strings" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/go-proton-api" + "github.com/bradenaw/juniper/xslices" + "github.com/google/uuid" + "golang.org/x/exp/slices" +) + +type message struct { + messageID string + externalID string + addrID string + labelIDs []string + sysLabel *string + attIDs []string + + subject string + sender *mail.Address + toList []*mail.Address + ccList []*mail.Address + bccList []*mail.Address + + armBody string + mimeType rfc822.MIMEType + + flags proton.MessageFlag + unread bool + starred bool +} + +func newMessage( + addrID string, + subject string, + sender *mail.Address, + toList, ccList, bccList []*mail.Address, + armBody string, + mimeType rfc822.MIMEType, + externalID string, +) *message { + return &message{ + messageID: uuid.NewString(), + externalID: externalID, + addrID: addrID, + sysLabel: pointer(""), + + subject: subject, + sender: sender, + toList: toList, + ccList: ccList, + bccList: bccList, + + armBody: armBody, + mimeType: mimeType, + } +} + +func (msg *message) toMessage(att map[string]*attachment) proton.Message { + return proton.Message{ + MessageMetadata: msg.toMetadata(), + + Header: msg.getHeader(), + ParsedHeaders: msg.getParsedHeaders(), + Body: msg.armBody, + MIMEType: msg.mimeType, + Attachments: xslices.Map(msg.attIDs, func(attID string) proton.Attachment { + return att[attID].toAttachment() + }), + } +} + +func (msg *message) toMetadata() proton.MessageMetadata { + labelIDs := []string{proton.AllMailLabel} + + if msg.flags.Has(proton.MessageFlagSent) { + labelIDs = append(labelIDs, proton.AllSentLabel) + } + + if !msg.flags.HasAny(proton.MessageFlagSent, proton.MessageFlagReceived) { + labelIDs = append(labelIDs, proton.AllDraftsLabel) + } + + if msg.starred { + labelIDs = append(labelIDs, proton.StarredLabel) + } + + if msg.sysLabel != nil { + if *msg.sysLabel != "" { + labelIDs = append(labelIDs, *msg.sysLabel) + } + } else { + switch { + case msg.flags.Has(proton.MessageFlagReceived): + labelIDs = append(labelIDs, proton.InboxLabel) + + case msg.flags.Has(proton.MessageFlagSent): + labelIDs = append(labelIDs, proton.SentLabel) + + default: + labelIDs = append(labelIDs, proton.DraftsLabel) + } + } + + return proton.MessageMetadata{ + ID: msg.messageID, + ExternalID: msg.externalID, + AddressID: msg.addrID, + LabelIDs: append(msg.labelIDs, labelIDs...), + + Subject: msg.subject, + Sender: msg.sender, + ToList: msg.toList, + CCList: msg.ccList, + BCCList: msg.bccList, + + Flags: msg.flags, + Unread: proton.Bool(msg.unread), + } +} + +func (msg *message) getHeader() string { + builder := new(strings.Builder) + + builder.WriteString("Subject: " + msg.subject + "\r\n") + + if msg.sender != nil { + builder.WriteString("From: " + msg.sender.String() + "\r\n") + } + + if len(msg.toList) > 0 { + builder.WriteString("To: " + toAddressList(msg.toList) + "\r\n") + } + + if len(msg.ccList) > 0 { + builder.WriteString("Cc: " + toAddressList(msg.ccList) + "\r\n") + } + + if len(msg.bccList) > 0 { + builder.WriteString("Bcc: " + toAddressList(msg.bccList) + "\r\n") + } + + if msg.mimeType != "" { + builder.WriteString("Content-Type: " + string(msg.mimeType) + "\r\n") + } + + return builder.String() +} + +func (msg *message) getParsedHeaders() proton.Headers { + header, err := rfc822.NewHeader([]byte(msg.getHeader())) + if err != nil { + panic(err) + } + + parsed := make(proton.Headers) + + header.Entries(func(key, value string) { + parsed[key] = append(parsed[key], value) + }) + + return parsed +} + +// applyChanges will apply non-nil field from passed message. +// +// NOTE: This is not feature complete. It might panic on non-implemented +// changes. +func (msg *message) applyChanges(changes proton.DraftTemplate) { + if changes.Subject != "" { + msg.subject = changes.Subject + } + + if changes.Sender != nil { + panic("sender change probably not allowed by API on existing draft") + } + + if changes.ToList != nil { + msg.toList = append([]*mail.Address{}, changes.ToList...) + } + + if changes.CCList != nil { + msg.ccList = append([]*mail.Address{}, changes.CCList...) + } + + if changes.BCCList != nil { + msg.bccList = append([]*mail.Address{}, changes.BCCList...) + } + + if changes.Body != "" { + msg.armBody = changes.Body + } + + if changes.MIMEType != "" { + msg.mimeType = changes.MIMEType + } + + if changes.ExternalID != "" { + msg.externalID = changes.ExternalID + } +} + +func (msg *message) addLabel(labelID string, labels map[string]*label) { + switch labelID { + case proton.InboxLabel, proton.SentLabel, proton.DraftsLabel: + msg.addFlagLabel(labelID, labels) + + case proton.TrashLabel, proton.SpamLabel, proton.ArchiveLabel: + msg.addSystemLabel(labelID, labels) + + case proton.StarredLabel: + msg.starred = true + + default: + if label, ok := labels[labelID]; ok { + msg.addUserLabel(label, labels) + } + } +} + +func (msg *message) addFlagLabel(labelID string, labels map[string]*label) { + msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { + return labels[otherLabelID].labelType == proton.LabelTypeLabel + }) + + msg.sysLabel = nil +} + +func (msg *message) addSystemLabel(labelID string, labels map[string]*label) { + msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { + return labels[otherLabelID].labelType == proton.LabelTypeLabel + }) + + msg.sysLabel = &labelID +} + +func (msg *message) addUserLabel(label *label, labels map[string]*label) { + if label.labelType != proton.LabelTypeLabel { + msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { + return labels[otherLabelID].labelType == proton.LabelTypeLabel + }) + + msg.sysLabel = pointer("") + } + + if !slices.Contains(msg.labelIDs, label.labelID) { + msg.labelIDs = append(msg.labelIDs, label.labelID) + } +} + +func (msg *message) remLabel(labelID string, labels map[string]*label) { + switch labelID { + case proton.InboxLabel, proton.SentLabel, proton.DraftsLabel: + msg.remFlagLabel(labelID, labels) + + case proton.TrashLabel, proton.SpamLabel, proton.ArchiveLabel: + msg.remSystemLabel(labelID, labels) + + case proton.StarredLabel: + msg.starred = false + + default: + if label, ok := labels[labelID]; ok { + msg.remUserLabel(label, labels) + } + } +} + +func (msg *message) remFlagLabel(labelID string, labels map[string]*label) { + msg.sysLabel = pointer("") +} + +func (msg *message) remSystemLabel(labelID string, labels map[string]*label) { + if msg.sysLabel != nil && *msg.sysLabel == labelID { + msg.sysLabel = pointer("") + } +} + +func (msg *message) remUserLabel(label *label, labels map[string]*label) { + msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { + return otherLabelID != label.labelID + }) +} + +func toAddressList(addrs []*mail.Address) string { + res := make([]string, len(addrs)) + + for i, addr := range addrs { + res[i] = addr.String() + } + + return strings.Join(res, ", ") +} + +func pointer[T any](v T) *T { + return &v +} diff --git a/server/backend/modulus.asc b/server/backend/modulus.asc new file mode 100644 index 0000000..26619ed --- /dev/null +++ b/server/backend/modulus.asc @@ -0,0 +1 @@ ++88jb48lF5TyDBveyHZ7QhSvtc4V3pN8/eQW6kk6ok2egy4lr5Wz9h8iZP3erN9lReSx1Lk+WsLu1b3soDhXX/twTCUhxYwjS8r983aEshZJJq7p5tNroQ5pzrZMbK8Oszjajgdg2YzcMcaJqb9+Doi7egj/esUQ+Q7BWdxeK77Wafj9v7PiW6Ozx6ulppu1mZ+YGnXSXJsl1Cl4nPm7PNkgj4BQT3HLrxakh7Xc3agmepRKO/1jLaOBU/oO17URbA5rwh/ZlAOqEAKH5vJ+hA2acM3Bwsa/K8I/jWicxOoaLZ4RZFpLYvOxGbb4DggR2Ri/C6tNyeEQQKAtxpeV5g== \ No newline at end of file diff --git a/server/backend/modulus.go b/server/backend/modulus.go new file mode 100644 index 0000000..feb77a2 --- /dev/null +++ b/server/backend/modulus.go @@ -0,0 +1,24 @@ +package backend + +import ( + _ "embed" + + "github.com/ProtonMail/gopenpgp/v2/crypto" +) + +var modulus string + +func init() { + arm, err := crypto.NewClearTextMessage(asc, sig).GetArmored() + if err != nil { + panic(err) + } + + modulus = arm +} + +//go:embed modulus.asc +var asc []byte + +//go:embed modulus.sig +var sig []byte diff --git a/server/backend/modulus.sig b/server/backend/modulus.sig new file mode 100644 index 0000000..4518a8e Binary files /dev/null and b/server/backend/modulus.sig differ diff --git a/server/backend/types.go b/server/backend/types.go new file mode 100644 index 0000000..87ad54b --- /dev/null +++ b/server/backend/types.go @@ -0,0 +1,112 @@ +package backend + +import ( + "encoding/base64" + "math/big" + "time" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/google/uuid" +) + +type ID uint64 + +func (v ID) String() string { + return base64.StdEncoding.EncodeToString(v.Bytes()) +} + +func (v ID) Bytes() []byte { + if v == 0 { + return []byte{0} + } + + return new(big.Int).SetUint64(uint64(v)).Bytes() +} + +func (v *ID) FromString(s string) error { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return err + } + + *v = ID(new(big.Int).SetBytes(b).Uint64()) + + return nil +} + +type auth struct { + acc string + ref string + + expiration time.Time + creation time.Time +} + +func newAuth(authLife time.Duration) auth { + return auth{ + acc: uuid.NewString(), + ref: uuid.NewString(), + + expiration: time.Now().Add(authLife), + creation: time.Now(), + } +} + +func (auth *auth) toAuth(userID, authUID string, proof []byte) proton.Auth { + return proton.Auth{ + UserID: userID, + + UID: authUID, + AccessToken: auth.acc, + RefreshToken: auth.ref, + ServerProof: base64.StdEncoding.EncodeToString(proof), + ExpiresIn: int(time.Until(auth.expiration).Seconds()), + + TwoFA: proton.TwoFAInfo{Enabled: proton.TwoFADisabled}, + PasswordMode: proton.OnePasswordMode, + } +} + +func (auth *auth) toAuthSession(authUID string) proton.AuthSession { + return proton.AuthSession{ + UID: authUID, + CreateTime: auth.creation.Unix(), + Revocable: true, + } +} + +type key struct { + keyID string + key string + tok string + sig string +} + +func (key key) unlock(passphrase []byte) (*crypto.KeyRing, error) { + lockedKey, err := crypto.NewKeyFromArmored(key.key) + if err != nil { + return nil, err + } + + unlockedKey, err := lockedKey.Unlock(passphrase) + if err != nil { + return nil, err + } + + return crypto.NewKeyRing(unlockedKey) +} + +func (key key) getPubKey() (*crypto.Key, error) { + privKey, err := crypto.NewKeyFromArmored(key.key) + if err != nil { + return nil, err + } + + pubKeyBin, err := privKey.GetPublicKey() + if err != nil { + return nil, err + } + + return crypto.NewKey(pubKeyBin) +} diff --git a/server/backend/types_test.go b/server/backend/types_test.go new file mode 100644 index 0000000..a5b29da --- /dev/null +++ b/server/backend/types_test.go @@ -0,0 +1,23 @@ +package backend + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestID(t *testing.T) { + var v ID + + // We can set the ID from a string. + require.NoError(t, v.FromString("AQIDBA==")) + + // We can get the ID as a string. + require.Equal(t, "AQIDBA==", v.String()) + + // We can get the ID as bytes. + require.Equal(t, []byte{1, 2, 3, 4}, v.Bytes()) + + // The ID is correct. + require.Equal(t, ID(0x01020304), v) +} diff --git a/server/backend/updates.go b/server/backend/updates.go new file mode 100644 index 0000000..2e643ae --- /dev/null +++ b/server/backend/updates.go @@ -0,0 +1,175 @@ +package backend + +import ( + "github.com/ProtonMail/go-proton-api" + "github.com/bradenaw/juniper/xslices" +) + +func merge(updates []update) []update { + if len(updates) < 2 { + return updates + } + + if merged := merge(updates[1:]); xslices.IndexFunc(merged, func(other update) bool { + return other.replaces(updates[0]) + }) < 0 { + return append([]update{updates[0]}, merged...) + } else { + return merged + } +} + +type update interface { + replaces(other update) bool + + _isUpdate() +} + +type baseUpdate struct{} + +func (baseUpdate) replaces(update) bool { + return false +} + +func (baseUpdate) _isUpdate() {} + +type userRefreshed struct { + baseUpdate + + refresh proton.RefreshFlag +} + +type messageCreated struct { + baseUpdate + messageID string +} + +type messageUpdated struct { + baseUpdate + messageID string +} + +func (update *messageUpdated) replaces(other update) bool { + switch other := other.(type) { + case *messageUpdated: + return update.messageID == other.messageID + + default: + return false + } +} + +type messageDeleted struct { + baseUpdate + messageID string +} + +func (update *messageDeleted) replaces(other update) bool { + switch other := other.(type) { + case *messageCreated: + return update.messageID == other.messageID + + case *messageUpdated: + return update.messageID == other.messageID + + case *messageDeleted: + if update.messageID != other.messageID { + return false + } + + panic("message deleted twice") + + default: + return false + } +} + +type labelCreated struct { + baseUpdate + labelID string +} + +type labelUpdated struct { + baseUpdate + labelID string +} + +func (update *labelUpdated) replaces(other update) bool { + switch other := other.(type) { + case *labelUpdated: + return update.labelID == other.labelID + + default: + return false + } +} + +type labelDeleted struct { + baseUpdate + labelID string +} + +func (update *labelDeleted) replaces(other update) bool { + switch other := other.(type) { + case *labelCreated: + return update.labelID == other.labelID + + case *labelUpdated: + return update.labelID == other.labelID + + case *labelDeleted: + if update.labelID != other.labelID { + return false + } + + panic("label deleted twice") + + default: + return false + } +} + +type addressCreated struct { + baseUpdate + addressID string +} + +type addressUpdated struct { + baseUpdate + addressID string +} + +func (update *addressUpdated) replaces(other update) bool { + switch other := other.(type) { + case *addressUpdated: + return update.addressID == other.addressID + + default: + return false + } +} + +type addressDeleted struct { + baseUpdate + addressID string +} + +func (update *addressDeleted) replaces(other update) bool { + switch other := other.(type) { + case *addressCreated: + return update.addressID == other.addressID + + case *addressUpdated: + return update.addressID == other.addressID + + case *addressDeleted: + if update.addressID != other.addressID { + return false + } + + panic("address deleted twice") + + default: + return false + } +} diff --git a/server/backend/updates_test.go b/server/backend/updates_test.go new file mode 100644 index 0000000..3aebcc2 --- /dev/null +++ b/server/backend/updates_test.go @@ -0,0 +1,63 @@ +package backend + +import ( + "reflect" + "testing" +) + +func Test_mergeUpdates(t *testing.T) { + tests := []struct { + name string + have []update + want []update + }{ + { + name: "single", + have: []update{&labelCreated{labelID: "1"}}, + want: []update{&labelCreated{labelID: "1"}}, + }, + { + name: "multiple", + have: []update{ + &labelCreated{labelID: "1"}, + &labelCreated{labelID: "2"}, + }, + want: []update{ + &labelCreated{labelID: "1"}, + &labelCreated{labelID: "2"}, + }, + }, + { + name: "replace with updated", + have: []update{ + &labelCreated{labelID: "1"}, + &labelUpdated{labelID: "1"}, + &labelUpdated{labelID: "1"}, + }, + want: []update{ + &labelCreated{labelID: "1"}, + &labelUpdated{labelID: "1"}, + }, + }, + { + name: "replace with delete", + have: []update{ + &labelCreated{labelID: "1"}, + &labelUpdated{labelID: "1"}, + &labelUpdated{labelID: "1"}, + &labelDeleted{labelID: "1"}, + }, + want: []update{ + &labelDeleted{labelID: "1"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := merge(tt.have); !reflect.DeepEqual(got, tt.want) { + t.Errorf("mergeUpdates() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/server/cache.go b/server/cache.go new file mode 100644 index 0000000..db7dfb6 --- /dev/null +++ b/server/cache.go @@ -0,0 +1,52 @@ +package server + +import ( + "sync" + + "github.com/ProtonMail/go-proton-api" +) + +func NewAuthCache() AuthCacher { + return &authCache{ + info: make(map[string]proton.AuthInfo), + auth: make(map[string]proton.Auth), + } +} + +type authCache struct { + info map[string]proton.AuthInfo + auth map[string]proton.Auth + lock sync.RWMutex +} + +func (c *authCache) GetAuthInfo(username string) (proton.AuthInfo, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + + info, ok := c.info[username] + + return info, ok +} + +func (c *authCache) SetAuthInfo(username string, info proton.AuthInfo) { + c.lock.Lock() + defer c.lock.Unlock() + + c.info[username] = info +} + +func (c *authCache) GetAuth(username string) (proton.Auth, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + + auth, ok := c.auth[username] + + return auth, ok +} + +func (c *authCache) SetAuth(username string, auth proton.Auth) { + c.lock.Lock() + defer c.lock.Unlock() + + c.auth[username] = auth +} diff --git a/server/call.go b/server/call.go new file mode 100644 index 0000000..9bd78bc --- /dev/null +++ b/server/call.go @@ -0,0 +1,50 @@ +package server + +import ( + "net/http" + "net/url" +) + +type Call struct { + URL *url.URL + Method string + Status int + + RequestHeader http.Header + RequestBody []byte + + ResponseHeader http.Header + ResponseBody []byte +} + +type callWatcher struct { + paths map[string]struct{} + callFn func(Call) +} + +func newCallWatcher(fn func(Call), paths ...string) callWatcher { + pathMap := make(map[string]struct{}, len(paths)) + + for _, path := range paths { + pathMap[path] = struct{}{} + } + + return callWatcher{ + paths: pathMap, + callFn: fn, + } +} + +func (watcher *callWatcher) isWatching(path string) bool { + if len(watcher.paths) == 0 { + return true + } + + _, ok := watcher.paths[path] + + return ok +} + +func (watcher *callWatcher) publish(call Call) { + watcher.callFn(call) +} diff --git a/server/cmd/client/client.go b/server/cmd/client/client.go new file mode 100644 index 0000000..900a414 --- /dev/null +++ b/server/cmd/client/client.go @@ -0,0 +1,291 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net" + "os" + + "github.com/ProtonMail/go-proton-api/server/proto" + "github.com/urfave/cli/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +func main() { + app := cli.NewApp() + + app.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "host", + Usage: "host to connect to", + Value: "localhost", + }, + &cli.IntFlag{ + Name: "port", + Usage: "port to connect to", + Value: 8080, + }, + } + + app.Commands = []*cli.Command{ + { + Name: "info", + Action: getInfoAction, + }, + { + Name: "auth", + Subcommands: []*cli.Command{ + { + Name: "revoke", + Action: revokeUserAction, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "userID", + Usage: "user ID to revoke", + Required: true, + }, + }, + }, + }, + }, + { + Name: "user", + Subcommands: []*cli.Command{ + { + Name: "create", + Action: createUserAction, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "username", + Usage: "username of the account", + Required: true, + }, + &cli.StringFlag{ + Name: "email", + Usage: "email of the account", + Required: true, + }, + &cli.StringFlag{ + Name: "password", + Usage: "password of the account", + Required: true, + }, + }, + }, + }, + }, + { + Name: "address", + Subcommands: []*cli.Command{ + { + Name: "create", + Action: createAddressAction, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "userID", + Usage: "ID of the user to create the address for", + Required: true, + }, + &cli.StringFlag{ + Name: "email", + Usage: "email of the account", + Required: true, + }, + &cli.StringFlag{ + Name: "password", + Usage: "password of the account", + Required: true, + }, + }, + }, + { + Name: "remove", + Action: removeAddressAction, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "userID", + Usage: "ID of the user to remove the address from", + Required: true, + }, + &cli.StringFlag{ + Name: "addressID", + Usage: "ID of the address to remove", + Required: true, + }, + }, + }, + }, + }, + { + Name: "label", + Subcommands: []*cli.Command{ + { + Name: "create", + Action: createLabelAction, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "userID", + Usage: "ID of the user to create the label for", + Required: true, + }, + &cli.StringFlag{ + Name: "name", + Usage: "name of the label", + Required: true, + }, + &cli.StringFlag{ + Name: "parentID", + Usage: "the ID of the parent label", + }, + &cli.BoolFlag{ + Name: "exclusive", + Usage: "Create an exclusive label (i.e. a folder)", + }, + }, + }, + }, + }, + } + + if err := app.Run(os.Args); err != nil { + log.Fatal(err) + } +} + +func getInfoAction(c *cli.Context) error { + client, err := newServerClient(c) + if err != nil { + return err + } + + res, err := client.GetInfo(c.Context, &proto.GetInfoRequest{}) + if err != nil { + return err + } + + return pretty(c.App.Writer, res) +} + +func createUserAction(c *cli.Context) error { + client, err := newServerClient(c) + if err != nil { + return err + } + + res, err := client.CreateUser(c.Context, &proto.CreateUserRequest{ + Username: c.String("username"), + Email: c.String("email"), + Password: []byte(c.String("password")), + }) + if err != nil { + return err + } + + return pretty(c.App.Writer, res) +} + +func revokeUserAction(c *cli.Context) error { + client, err := newServerClient(c) + if err != nil { + return err + } + + res, err := client.RevokeUser(c.Context, &proto.RevokeUserRequest{ + UserID: c.String("userID"), + }) + if err != nil { + return err + } + + return pretty(c.App.Writer, res) +} + +func createAddressAction(c *cli.Context) error { + client, err := newServerClient(c) + if err != nil { + return err + } + + res, err := client.CreateAddress(c.Context, &proto.CreateAddressRequest{ + UserID: c.String("userID"), + Email: c.String("email"), + Password: []byte(c.String("password")), + }) + if err != nil { + return err + } + + return pretty(c.App.Writer, res) +} + +func removeAddressAction(c *cli.Context) error { + client, err := newServerClient(c) + if err != nil { + return err + } + + res, err := client.RemoveAddress(c.Context, &proto.RemoveAddressRequest{ + UserID: c.String("userID"), + AddrID: c.String("addressID"), + }) + if err != nil { + return err + } + + return pretty(c.App.Writer, res) +} + +func createLabelAction(c *cli.Context) error { + client, err := newServerClient(c) + if err != nil { + return err + } + + var labelType proto.LabelType + + if c.Bool("exclusive") { + labelType = proto.LabelType_FOLDER + } else { + labelType = proto.LabelType_LABEL + } + + res, err := client.CreateLabel(c.Context, &proto.CreateLabelRequest{ + UserID: c.String("userID"), + Name: c.String("name"), + Type: labelType, + }) + if err != nil { + return err + } + + return pretty(c.App.Writer, res) +} + +func newServerClient(c *cli.Context) (proto.ServerClient, error) { + cc, err := grpc.DialContext( + c.Context, + net.JoinHostPort(c.String("host"), fmt.Sprint(c.Int("port"))), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("failed to dial: %w", err) + } + + return proto.NewServerClient(cc), nil +} + +func pretty[T any](w io.Writer, v T) error { + enc, err := json.MarshalIndent(v, "", " ") + if err != nil { + return err + } + + if _, err := w.Write(enc); err != nil { + return err + } + + return nil +} diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go new file mode 100644 index 0000000..533ddee --- /dev/null +++ b/server/cmd/server/main.go @@ -0,0 +1,140 @@ +package main + +import ( + "context" + "fmt" + "log" + "net" + "os" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server" + "github.com/ProtonMail/go-proton-api/server/proto" + "github.com/urfave/cli/v2" + "google.golang.org/grpc" +) + +func main() { + app := cli.NewApp() + + app.Flags = []cli.Flag{ + &cli.IntFlag{ + Name: "port", + Aliases: []string{"p"}, + Usage: "port to serve gRPC on", + Value: 8080, + }, + &cli.BoolFlag{ + Name: "tls", + }, + } + + app.Action = run + + if err := app.Run(os.Args); err != nil { + log.Fatal(err) + } +} + +func run(c *cli.Context) error { + s := server.New(server.WithTLS(c.Bool("tls"))) + defer s.Close() + + return newService(s).run(c.Int("port")) +} + +type service struct { + proto.UnimplementedServerServer + + server *server.Server + + gRPCServer *grpc.Server +} + +func newService(server *server.Server) *service { + s := &service{ + server: server, + + gRPCServer: grpc.NewServer(), + } + + proto.RegisterServerServer(s.gRPCServer, s) + + return s +} + +func (s *service) GetInfo(ctx context.Context, req *proto.GetInfoRequest) (*proto.GetInfoResponse, error) { + return &proto.GetInfoResponse{ + HostURL: s.server.GetHostURL(), + ProxyURL: s.server.GetProxyURL(), + }, nil +} + +func (s *service) CreateUser(ctx context.Context, req *proto.CreateUserRequest) (*proto.CreateUserResponse, error) { + userID, addrID, err := s.server.CreateUser(req.Username, req.Email, req.Password) + if err != nil { + return nil, err + } + + return &proto.CreateUserResponse{ + UserID: userID, + AddrID: addrID, + }, nil +} + +func (s *service) RevokeUser(ctx context.Context, req *proto.RevokeUserRequest) (*proto.RevokeUserResponse, error) { + if err := s.server.RevokeUser(req.UserID); err != nil { + return nil, err + } + + return &proto.RevokeUserResponse{}, nil +} + +func (s *service) CreateAddress(ctx context.Context, req *proto.CreateAddressRequest) (*proto.CreateAddressResponse, error) { + addrID, err := s.server.CreateAddress(req.UserID, req.Email, req.Password) + if err != nil { + return nil, err + } + + return &proto.CreateAddressResponse{ + AddrID: addrID, + }, nil +} + +func (s *service) RemoveAddress(ctx context.Context, req *proto.RemoveAddressRequest) (*proto.RemoveAddressResponse, error) { + if err := s.server.RemoveAddress(req.UserID, req.AddrID); err != nil { + return nil, err + } + + return &proto.RemoveAddressResponse{}, nil +} + +func (s *service) CreateLabel(ctx context.Context, req *proto.CreateLabelRequest) (*proto.CreateLabelResponse, error) { + var labelType proton.LabelType + + switch req.Type { + case proto.LabelType_FOLDER: + labelType = proton.LabelTypeFolder + + case proto.LabelType_LABEL: + labelType = proton.LabelTypeLabel + } + + labelID, err := s.server.CreateLabel(req.UserID, req.Name, req.ParentID, labelType) + if err != nil { + return nil, err + } + + return &proto.CreateLabelResponse{ + LabelID: labelID, + }, nil +} + +func (s *service) run(port int) error { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + return err + } + + return s.gRPCServer.Serve(listener) +} diff --git a/server/contacts.go b/server/contacts.go new file mode 100644 index 0000000..b1cc226 --- /dev/null +++ b/server/contacts.go @@ -0,0 +1,16 @@ +package server + +import ( + "net/http" + + "github.com/ProtonMail/go-proton-api" + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetContactsEmails() gin.HandlerFunc { + return func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "ContactEmails": []proton.ContactEmail{}, + }) + } +} diff --git a/server/errors.go b/server/errors.go new file mode 100644 index 0000000..99dec48 --- /dev/null +++ b/server/errors.go @@ -0,0 +1,9 @@ +package server + +import "errors" + +var ( + ErrNoSuchUser = errors.New("no such user") + ErrNoSuchAddress = errors.New("no such address") + ErrNoSuchLabel = errors.New("no such label") +) diff --git a/server/events.go b/server/events.go new file mode 100644 index 0000000..0052763 --- /dev/null +++ b/server/events.go @@ -0,0 +1,33 @@ +package server + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetEvents() gin.HandlerFunc { + return func(c *gin.Context) { + event, err := s.b.GetEvent(c.GetString("UserID"), c.Param("eventID")) + if err != nil { + _ = c.AbortWithError(http.StatusBadRequest, err) + return + } + + c.JSON(http.StatusOK, event) + } +} + +func (s *Server) handleGetEventsLatest() gin.HandlerFunc { + return func(c *gin.Context) { + eventID, err := s.b.GetLatestEventID(c.GetString("UserID")) + if err != nil { + _ = c.AbortWithError(http.StatusBadRequest, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "EventID": eventID, + }) + } +} diff --git a/server/keys.go b/server/keys.go new file mode 100644 index 0000000..ab09648 --- /dev/null +++ b/server/keys.go @@ -0,0 +1,37 @@ +package server + +import ( + "net/http" + + "github.com/ProtonMail/go-proton-api" + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetKeys() gin.HandlerFunc { + return func(c *gin.Context) { + if pubKeys, err := s.b.GetPublicKeys(c.Query("Email")); err == nil && len(pubKeys) > 0 { + c.JSON(http.StatusOK, gin.H{ + "Keys": pubKeys, + "RecipientType": proton.RecipientTypeInternal, + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "RecipientType": proton.RecipientTypeExternal, + }) + } + } +} + +func (s *Server) handleGetKeySalts() gin.HandlerFunc { + return func(c *gin.Context) { + salts, err := s.b.GetKeySalts(c.GetString("UserID")) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "KeySalts": salts, + }) + } +} diff --git a/server/labels.go b/server/labels.go new file mode 100644 index 0000000..c2efab8 --- /dev/null +++ b/server/labels.go @@ -0,0 +1,100 @@ +package server + +import ( + "net/http" + "strconv" + + "github.com/ProtonMail/go-proton-api" + "github.com/bradenaw/juniper/xslices" + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetMailLabels() gin.HandlerFunc { + return func(c *gin.Context) { + types := xslices.Map(c.QueryArray("Type"), func(val string) proton.LabelType { + labelType, err := strconv.Atoi(val) + if err != nil { + panic(err) + } + + return proton.LabelType(labelType) + }) + + labels, err := s.b.GetLabels(c.GetString("UserID"), types...) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Labels": labels, + }) + } +} + +func (s *Server) handlePostMailLabels() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.CreateLabelReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + if _, has, err := s.b.HasLabel(c.GetString("UserID"), req.Name); err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } else if has { + c.AbortWithStatus(http.StatusConflict) + return + } + + label, err := s.b.CreateLabel(c.GetString("UserID"), req.Name, req.ParentID, req.Type) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Label": label, + }) + } +} + +func (s *Server) handlePutMailLabel() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.UpdateLabelReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + if labelID, has, err := s.b.HasLabel(c.GetString("UserID"), req.Name); err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } else if has && labelID != c.Param("labelID") { + c.AbortWithStatus(http.StatusConflict) + return + } + + label, err := s.b.UpdateLabel(c.GetString("UserID"), c.Param("labelID"), req.Name, req.ParentID) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Label": label, + }) + } +} + +func (s *Server) handleDeleteMailLabel() gin.HandlerFunc { + return func(c *gin.Context) { + if err := s.b.DeleteLabel(c.GetString("UserID"), c.Param("labelID")); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + } +} diff --git a/server/main_test.go b/server/main_test.go new file mode 100644 index 0000000..560a4ad --- /dev/null +++ b/server/main_test.go @@ -0,0 +1,11 @@ +package server + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, goleak.IgnoreCurrent()) +} diff --git a/server/messages.go b/server/messages.go new file mode 100644 index 0000000..bfb907c --- /dev/null +++ b/server/messages.go @@ -0,0 +1,573 @@ +package server + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "mime" + "net/http" + "net/mail" + "strconv" + + "github.com/ProtonMail/gluon/rfc822" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/gin-gonic/gin" + "golang.org/x/exp/slices" +) + +func (s *Server) handleGetMailMessages() gin.HandlerFunc { + return func(c *gin.Context) { + filter := proton.MessageFilter{ + ID: c.QueryArray("ID"), + } + s.getMailMessages( + c, + mustParseInt(c.DefaultQuery("Page", "0")), + mustParseInt(c.DefaultQuery("PageSize", "100")), + filter, + ) + } +} + +func (s *Server) getMailMessages(c *gin.Context, page, pageSize int, filter proton.MessageFilter) { + messages, err := s.b.GetMessages(c.GetString("UserID"), page, pageSize, filter) + if err != nil { + _ = c.AbortWithError(http.StatusInternalServerError, err) + return + } + + total, err := s.b.CountMessages(c.GetString("UserID")) + if err != nil { + _ = c.AbortWithError(http.StatusInternalServerError, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Messages": messages, + "Total": total, + "Stale": proton.APIFalse, + }) +} + +func (s *Server) handlePostMailMessages() gin.HandlerFunc { + return func(c *gin.Context) { + switch c.GetHeader("X-HTTP-Method-Override") { + case "GET": + var req struct { + proton.MessageFilter + + Page int + PageSize int + } + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + s.getMailMessages(c, req.Page, req.PageSize, req.MessageFilter) + + default: + s.postMailMessages(c) + } + } +} + +func (s *Server) postMailMessages(c *gin.Context) { + var req proton.CreateDraftReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + addrID, err := s.b.GetAddressID(req.Message.Sender.Address) + if err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + + message, err := s.b.CreateDraft( + c.GetString("UserID"), + addrID, + req.Message.Subject, + req.Message.Sender, + req.Message.ToList, + req.Message.CCList, + req.Message.BCCList, + req.Message.Body, + req.Message.MIMEType, + req.Message.ExternalID, + ) + if err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Message": message, + }) +} + +func (s *Server) handleGetMailMessageIDs() gin.HandlerFunc { + return func(c *gin.Context) { + limit, err := strconv.Atoi(c.Query("Limit")) + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + messageIDs, err := s.b.GetMessageIDs(c.GetString("UserID"), c.Query("AfterID"), limit) + if err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + + c.JSON(http.StatusOK, gin.H{ + "IDs": messageIDs, + }) + } +} + +func (s *Server) handleGetMailMessage() gin.HandlerFunc { + return func(c *gin.Context) { + message, err := s.b.GetMessage(c.GetString("UserID"), c.Param("messageID")) + if err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Message": message, + }) + } +} + +func (s *Server) handlePostMailMessage() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.SendDraftReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + message, err := s.b.SendMessage(c.GetString("UserID"), c.Param("messageID"), req.Packages) + if err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Sent": message, + }) + } +} + +func (s *Server) handlePutMailMessagesRead() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.MessageActionReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + if err := s.b.SetMessagesRead(c.GetString("UserID"), true, req.IDs...); err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + } +} + +func (s *Server) handlePutMailMessagesUnread() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.MessageActionReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + if err := s.b.SetMessagesRead(c.GetString("UserID"), false, req.IDs...); err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + } +} + +func (s *Server) handlePutMailMessagesLabel() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.LabelMessagesReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + if err := s.b.LabelMessages(c.GetString("UserID"), req.LabelID, req.IDs...); err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + } +} + +func (s *Server) handlePutMailMessagesUnlabel() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.LabelMessagesReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + if err := s.b.UnlabelMessages(c.GetString("UserID"), req.LabelID, req.IDs...); err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + } +} + +func (s *Server) handlePutMailMessagesImport() gin.HandlerFunc { + return func(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + var metadata map[string]proton.ImportMetadata + + if err := json.Unmarshal([]byte(form.Value["Metadata"][0]), &metadata); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + files := make(map[string][]byte) + + for name, file := range form.File { + files[name] = mustReadFileHeader(file[0]) + } + + type response struct { + Name string + Response proton.ImportRes + } + + var responses []response + + for name, literal := range files { + res := response{Name: name} + + messageID, err := s.importMessage( + c.GetString("UserID"), + metadata[name].AddressID, + metadata[name].LabelIDs, + literal, + metadata[name].Flags, + bool(metadata[name].Unread), + ) + if err != nil { + res.Response = proton.ImportRes{ + Error: proton.Error{ + Code: proton.InvalidValue, + Message: fmt.Sprintf("failed to import: %v", err), + }, + } + } else { + res.Response = proton.ImportRes{ + Error: proton.Error{Code: proton.SuccessCode}, + MessageID: messageID, + } + } + + responses = append(responses, res) + } + + c.JSON(http.StatusOK, gin.H{ + "Code": proton.MultiCode, + "Responses": responses, + }) + } +} + +func (s *Server) handleDeleteMailMessages() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.MessageActionReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + for _, messageID := range req.IDs { + if err := s.b.DeleteMessage(c.GetString("UserID"), messageID); err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + } + } +} + +func (s *Server) importMessage( + userID, addrID string, + labelIDs []string, + literal []byte, + flags proton.MessageFlag, + unread bool, +) (string, error) { + var exclusive int + + for _, labelID := range labelIDs { + switch labelID { + case proton.AllDraftsLabel, proton.AllSentLabel, proton.AllMailLabel, proton.OutboxLabel: + return "", fmt.Errorf("invalid label ID: %s", labelID) + } + + label, err := s.b.GetLabel(userID, labelID) + if err != nil { + return "", fmt.Errorf("invalid label ID: %s", labelID) + } + + if label.Type != proton.LabelTypeLabel { + exclusive++ + } + } + + if exclusive > 1 { + return "", fmt.Errorf("too many exclusive labels") + } + + header, body, atts, mimeType, err := s.parseMessage(literal) + if err != nil { + return "", fmt.Errorf("failed to parse message: %w", err) + } + + messageID, err := s.importBody(userID, addrID, header, body, mimeType, flags, unread, slices.Contains(labelIDs, proton.StarredLabel)) + if err != nil { + return "", fmt.Errorf("failed to import message: %w", err) + } + + for _, att := range atts { + if _, err := s.importAttachment(userID, messageID, att); err != nil { + return "", fmt.Errorf("failed to import attachment: %w", err) + } + } + + for _, labelID := range labelIDs { + if err := s.b.LabelMessages(userID, labelID, messageID); err != nil { + return "", fmt.Errorf("failed to label message: %w", err) + } + } + + return messageID, nil +} + +func (s *Server) parseMessage(literal []byte) (*rfc822.Header, []string, []*rfc822.Section, rfc822.MIMEType, error) { + root := rfc822.Parse(literal) + + header, err := root.ParseHeader() + if err != nil { + return nil, nil, nil, "", fmt.Errorf("failed to parse header: %w", err) + } + + body, atts, err := collect(root) + if err != nil { + return nil, nil, nil, "", fmt.Errorf("failed to collect body and attachments: %w", err) + } + + mimeType, _, err := root.ContentType() + if err != nil { + return nil, nil, nil, "", fmt.Errorf("failed to parse content type: %w", err) + } + + // Force all multipart types to be multipart/mixed. + if mimeType.Type() == "multipart" { + mimeType = "multipart/mixed" + } + + return header, body, atts, mimeType, nil +} + +func collect(section *rfc822.Section) ([]string, []*rfc822.Section, error) { + mimeType, _, err := section.ContentType() + if err != nil { + return nil, nil, fmt.Errorf("failed to parse content type: %w", err) + } + + switch mimeType.Type() { + case "text": + return []string{string(section.Body())}, nil, nil + + case "multipart": + children, err := section.Children() + if err != nil { + return nil, nil, fmt.Errorf("failed to parse children: %w", err) + } + + switch mimeType.SubType() { + case "encrypted": + if len(children) != 2 { + return nil, nil, fmt.Errorf("expected two children for multipart/encrypted, got %d", len(children)) + } + + return []string{string(children[1].Body())}, nil, nil + + default: + var ( + multiBody []string + multiAtts []*rfc822.Section + ) + + for _, child := range children { + body, atts, err := collect(child) + if err != nil { + return nil, nil, fmt.Errorf("failed to collect child: %w", err) + } + + multiBody = append(multiBody, body...) + multiAtts = append(multiAtts, atts...) + } + + return multiBody, multiAtts, nil + } + + default: + return nil, []*rfc822.Section{section}, nil + } +} + +func (s *Server) importBody( + userID, addrID string, + header *rfc822.Header, + body []string, + mimeType rfc822.MIMEType, + flags proton.MessageFlag, + unread, starred bool, +) (string, error) { + subject := header.Get("Subject") + sender := tryParseAddress(header.Get("From")) + toList := tryParseAddressList(header.Get("To")) + ccList := tryParseAddressList(header.Get("Cc")) + bccList := tryParseAddressList(header.Get("Bcc")) + + // NOTE: Importing just the first body part matches API behaviour but sucks! + return s.b.CreateMessage( + userID, addrID, + subject, + sender, + toList, ccList, bccList, + string(body[0]), + rfc822.MIMEType(mimeType), + flags, + unread, starred, + ) +} + +func (s *Server) importAttachment(userID, messageID string, att *rfc822.Section) (proton.Attachment, error) { + header, err := att.ParseHeader() + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to parse attachment header: %w", err) + } + + mimeType, _, err := att.ContentType() + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to parse attachment content type: %w", err) + } + + var disposition, filename string + + if header.Has("Content-Disposition") { + dispType, dispParams, err := mime.ParseMediaType(header.Get("Content-Disposition")) + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to parse attachment content disposition: %w", err) + } + + disposition = dispType + filename = dispParams["filename"] + } else { + disposition = "attachment" + filename = "attachment.bin" + } + + var body *crypto.PGPSplitMessage + + if header.Get("Content-Transfer-Encoding") == "base64" { + b := make([]byte, base64.StdEncoding.DecodedLen(len(att.Body()))) + + n, err := base64.StdEncoding.Decode(b, att.Body()) + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to decode attachment body: %w", err) + } + + split, err := crypto.NewPGPMessage(b[:n]).SplitMessage() + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to split attachment body: %w", err) + } + + body = split + } else { + msg, err := crypto.NewPGPMessageFromArmored(string(att.Body())) + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to parse attachment body: %w", err) + } + + split, err := msg.SplitMessage() + if err != nil { + return proton.Attachment{}, fmt.Errorf("failed to split attachment body: %w", err) + } + + body = split + } + + // TODO: What about the signature? + return s.b.CreateAttachment( + userID, messageID, + filename, + mimeType, + proton.Disposition(disposition), + body.GetBinaryKeyPacket(), + body.GetBinaryDataPacket(), + "", + ) +} + +func tryParseAddress(s string) *mail.Address { + if s == "" { + return nil + } + + addr, err := mail.ParseAddress(s) + if err != nil { + return &mail.Address{ + Name: s, + } + } + + return addr +} + +func tryParseAddressList(s string) []*mail.Address { + if s == "" { + return nil + } + + addrs, err := mail.ParseAddressList(s) + if err != nil { + return []*mail.Address{{ + Name: s, + }} + } + + return addrs +} + +func mustParseInt(num string) int { + val, err := strconv.Atoi(num) + if err != nil { + panic(err) + } + + return val +} diff --git a/server/ping.go b/server/ping.go new file mode 100644 index 0000000..e2fb7aa --- /dev/null +++ b/server/ping.go @@ -0,0 +1,7 @@ +package server + +import "github.com/gin-gonic/gin" + +func (s *Server) handleGetPing() gin.HandlerFunc { + return func(c *gin.Context) {} +} diff --git a/server/proto/server.go b/server/proto/server.go new file mode 100644 index 0000000..e5837c1 --- /dev/null +++ b/server/proto/server.go @@ -0,0 +1,3 @@ +package proto + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative server.proto diff --git a/server/proto/server.pb.go b/server/proto/server.pb.go new file mode 100644 index 0000000..2a4b3c0 --- /dev/null +++ b/server/proto/server.pb.go @@ -0,0 +1,993 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.28.1 +// protoc v3.21.7 +// source: server.proto + +package proto + +import ( + reflect "reflect" + sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type LabelType int32 + +const ( + LabelType_FOLDER LabelType = 0 + LabelType_LABEL LabelType = 1 +) + +// Enum value maps for LabelType. +var ( + LabelType_name = map[int32]string{ + 0: "FOLDER", + 1: "LABEL", + } + LabelType_value = map[string]int32{ + "FOLDER": 0, + "LABEL": 1, + } +) + +func (x LabelType) Enum() *LabelType { + p := new(LabelType) + *p = x + return p +} + +func (x LabelType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (LabelType) Descriptor() protoreflect.EnumDescriptor { + return file_server_proto_enumTypes[0].Descriptor() +} + +func (LabelType) Type() protoreflect.EnumType { + return &file_server_proto_enumTypes[0] +} + +func (x LabelType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use LabelType.Descriptor instead. +func (LabelType) EnumDescriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{0} +} + +type GetInfoRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *GetInfoRequest) Reset() { + *x = GetInfoRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetInfoRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetInfoRequest) ProtoMessage() {} + +func (x *GetInfoRequest) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetInfoRequest.ProtoReflect.Descriptor instead. +func (*GetInfoRequest) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{0} +} + +type GetInfoResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + HostURL string `protobuf:"bytes,1,opt,name=hostURL,proto3" json:"hostURL,omitempty"` + ProxyURL string `protobuf:"bytes,2,opt,name=proxyURL,proto3" json:"proxyURL,omitempty"` +} + +func (x *GetInfoResponse) Reset() { + *x = GetInfoResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetInfoResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetInfoResponse) ProtoMessage() {} + +func (x *GetInfoResponse) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetInfoResponse.ProtoReflect.Descriptor instead. +func (*GetInfoResponse) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{1} +} + +func (x *GetInfoResponse) GetHostURL() string { + if x != nil { + return x.HostURL + } + return "" +} + +func (x *GetInfoResponse) GetProxyURL() string { + if x != nil { + return x.ProxyURL + } + return "" +} + +type CreateUserRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + Email string `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"` + Password []byte `protobuf:"bytes,3,opt,name=password,proto3" json:"password,omitempty"` +} + +func (x *CreateUserRequest) Reset() { + *x = CreateUserRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateUserRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateUserRequest) ProtoMessage() {} + +func (x *CreateUserRequest) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateUserRequest.ProtoReflect.Descriptor instead. +func (*CreateUserRequest) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{2} +} + +func (x *CreateUserRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *CreateUserRequest) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + +func (x *CreateUserRequest) GetPassword() []byte { + if x != nil { + return x.Password + } + return nil +} + +type CreateUserResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"` + AddrID string `protobuf:"bytes,2,opt,name=addrID,proto3" json:"addrID,omitempty"` +} + +func (x *CreateUserResponse) Reset() { + *x = CreateUserResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateUserResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateUserResponse) ProtoMessage() {} + +func (x *CreateUserResponse) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateUserResponse.ProtoReflect.Descriptor instead. +func (*CreateUserResponse) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{3} +} + +func (x *CreateUserResponse) GetUserID() string { + if x != nil { + return x.UserID + } + return "" +} + +func (x *CreateUserResponse) GetAddrID() string { + if x != nil { + return x.AddrID + } + return "" +} + +type RevokeUserRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"` +} + +func (x *RevokeUserRequest) Reset() { + *x = RevokeUserRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RevokeUserRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeUserRequest) ProtoMessage() {} + +func (x *RevokeUserRequest) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeUserRequest.ProtoReflect.Descriptor instead. +func (*RevokeUserRequest) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{4} +} + +func (x *RevokeUserRequest) GetUserID() string { + if x != nil { + return x.UserID + } + return "" +} + +type RevokeUserResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *RevokeUserResponse) Reset() { + *x = RevokeUserResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RevokeUserResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeUserResponse) ProtoMessage() {} + +func (x *RevokeUserResponse) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeUserResponse.ProtoReflect.Descriptor instead. +func (*RevokeUserResponse) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{5} +} + +type CreateAddressRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"` + Email string `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"` + Password []byte `protobuf:"bytes,3,opt,name=password,proto3" json:"password,omitempty"` +} + +func (x *CreateAddressRequest) Reset() { + *x = CreateAddressRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateAddressRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateAddressRequest) ProtoMessage() {} + +func (x *CreateAddressRequest) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateAddressRequest.ProtoReflect.Descriptor instead. +func (*CreateAddressRequest) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{6} +} + +func (x *CreateAddressRequest) GetUserID() string { + if x != nil { + return x.UserID + } + return "" +} + +func (x *CreateAddressRequest) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + +func (x *CreateAddressRequest) GetPassword() []byte { + if x != nil { + return x.Password + } + return nil +} + +type CreateAddressResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AddrID string `protobuf:"bytes,1,opt,name=addrID,proto3" json:"addrID,omitempty"` +} + +func (x *CreateAddressResponse) Reset() { + *x = CreateAddressResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateAddressResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateAddressResponse) ProtoMessage() {} + +func (x *CreateAddressResponse) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateAddressResponse.ProtoReflect.Descriptor instead. +func (*CreateAddressResponse) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{7} +} + +func (x *CreateAddressResponse) GetAddrID() string { + if x != nil { + return x.AddrID + } + return "" +} + +type RemoveAddressRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"` + AddrID string `protobuf:"bytes,2,opt,name=addrID,proto3" json:"addrID,omitempty"` +} + +func (x *RemoveAddressRequest) Reset() { + *x = RemoveAddressRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RemoveAddressRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RemoveAddressRequest) ProtoMessage() {} + +func (x *RemoveAddressRequest) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RemoveAddressRequest.ProtoReflect.Descriptor instead. +func (*RemoveAddressRequest) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{8} +} + +func (x *RemoveAddressRequest) GetUserID() string { + if x != nil { + return x.UserID + } + return "" +} + +func (x *RemoveAddressRequest) GetAddrID() string { + if x != nil { + return x.AddrID + } + return "" +} + +type RemoveAddressResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *RemoveAddressResponse) Reset() { + *x = RemoveAddressResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RemoveAddressResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RemoveAddressResponse) ProtoMessage() {} + +func (x *RemoveAddressResponse) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RemoveAddressResponse.ProtoReflect.Descriptor instead. +func (*RemoveAddressResponse) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{9} +} + +type CreateLabelRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + ParentID string `protobuf:"bytes,3,opt,name=parentID,proto3" json:"parentID,omitempty"` + Type LabelType `protobuf:"varint,4,opt,name=type,proto3,enum=proto.LabelType" json:"type,omitempty"` +} + +func (x *CreateLabelRequest) Reset() { + *x = CreateLabelRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateLabelRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateLabelRequest) ProtoMessage() {} + +func (x *CreateLabelRequest) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateLabelRequest.ProtoReflect.Descriptor instead. +func (*CreateLabelRequest) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{10} +} + +func (x *CreateLabelRequest) GetUserID() string { + if x != nil { + return x.UserID + } + return "" +} + +func (x *CreateLabelRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *CreateLabelRequest) GetParentID() string { + if x != nil { + return x.ParentID + } + return "" +} + +func (x *CreateLabelRequest) GetType() LabelType { + if x != nil { + return x.Type + } + return LabelType_FOLDER +} + +type CreateLabelResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + LabelID string `protobuf:"bytes,1,opt,name=labelID,proto3" json:"labelID,omitempty"` +} + +func (x *CreateLabelResponse) Reset() { + *x = CreateLabelResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_server_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CreateLabelResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CreateLabelResponse) ProtoMessage() {} + +func (x *CreateLabelResponse) ProtoReflect() protoreflect.Message { + mi := &file_server_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CreateLabelResponse.ProtoReflect.Descriptor instead. +func (*CreateLabelResponse) Descriptor() ([]byte, []int) { + return file_server_proto_rawDescGZIP(), []int{11} +} + +func (x *CreateLabelResponse) GetLabelID() string { + if x != nil { + return x.LabelID + } + return "" +} + +var File_server_proto protoreflect.FileDescriptor + +var file_server_proto_rawDesc = []byte{ + 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x10, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x47, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x68, 0x6f, + 0x73, 0x74, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x68, 0x6f, 0x73, + 0x74, 0x55, 0x52, 0x4c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c, + 0x22, 0x61, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, + 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, + 0x6f, 0x72, 0x64, 0x22, 0x44, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, + 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, + 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, + 0x44, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x2b, 0x0a, 0x11, 0x52, 0x65, 0x76, + 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, + 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, + 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x60, 0x0a, 0x14, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, + 0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x2f, + 0x0a, 0x15, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, + 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, + 0x46, 0x0a, 0x14, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, + 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, + 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x17, 0x0a, 0x15, 0x52, 0x65, 0x6d, 0x6f, 0x76, + 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x82, 0x01, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, + 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, + 0x24, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x52, + 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x2f, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, + 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, + 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, + 0x61, 0x62, 0x65, 0x6c, 0x49, 0x44, 0x2a, 0x22, 0x0a, 0x09, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, + 0x79, 0x70, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x46, 0x4f, 0x4c, 0x44, 0x45, 0x52, 0x10, 0x00, 0x12, + 0x09, 0x0a, 0x05, 0x4c, 0x41, 0x42, 0x45, 0x4c, 0x10, 0x01, 0x32, 0xa6, 0x03, 0x0a, 0x06, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x38, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x12, 0x15, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x41, 0x0a, 0x0a, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x41, 0x0a, 0x0a, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, + 0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, + 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, + 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, + 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, + 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, + 0x0b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x19, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, + 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x42, 0x2e, 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x6e, 0x74, 0x65, 0x63, 0x68, 0x2e, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x2f, 0x6c, + 0x69, 0x74, 0x65, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_server_proto_rawDescOnce sync.Once + file_server_proto_rawDescData = file_server_proto_rawDesc +) + +func file_server_proto_rawDescGZIP() []byte { + file_server_proto_rawDescOnce.Do(func() { + file_server_proto_rawDescData = protoimpl.X.CompressGZIP(file_server_proto_rawDescData) + }) + return file_server_proto_rawDescData +} + +var file_server_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_server_proto_msgTypes = make([]protoimpl.MessageInfo, 12) +var file_server_proto_goTypes = []interface{}{ + (LabelType)(0), // 0: proto.LabelType + (*GetInfoRequest)(nil), // 1: proto.GetInfoRequest + (*GetInfoResponse)(nil), // 2: proto.GetInfoResponse + (*CreateUserRequest)(nil), // 3: proto.CreateUserRequest + (*CreateUserResponse)(nil), // 4: proto.CreateUserResponse + (*RevokeUserRequest)(nil), // 5: proto.RevokeUserRequest + (*RevokeUserResponse)(nil), // 6: proto.RevokeUserResponse + (*CreateAddressRequest)(nil), // 7: proto.CreateAddressRequest + (*CreateAddressResponse)(nil), // 8: proto.CreateAddressResponse + (*RemoveAddressRequest)(nil), // 9: proto.RemoveAddressRequest + (*RemoveAddressResponse)(nil), // 10: proto.RemoveAddressResponse + (*CreateLabelRequest)(nil), // 11: proto.CreateLabelRequest + (*CreateLabelResponse)(nil), // 12: proto.CreateLabelResponse +} +var file_server_proto_depIdxs = []int32{ + 0, // 0: proto.CreateLabelRequest.type:type_name -> proto.LabelType + 1, // 1: proto.Server.GetInfo:input_type -> proto.GetInfoRequest + 3, // 2: proto.Server.CreateUser:input_type -> proto.CreateUserRequest + 5, // 3: proto.Server.RevokeUser:input_type -> proto.RevokeUserRequest + 7, // 4: proto.Server.CreateAddress:input_type -> proto.CreateAddressRequest + 9, // 5: proto.Server.RemoveAddress:input_type -> proto.RemoveAddressRequest + 11, // 6: proto.Server.CreateLabel:input_type -> proto.CreateLabelRequest + 2, // 7: proto.Server.GetInfo:output_type -> proto.GetInfoResponse + 4, // 8: proto.Server.CreateUser:output_type -> proto.CreateUserResponse + 6, // 9: proto.Server.RevokeUser:output_type -> proto.RevokeUserResponse + 8, // 10: proto.Server.CreateAddress:output_type -> proto.CreateAddressResponse + 10, // 11: proto.Server.RemoveAddress:output_type -> proto.RemoveAddressResponse + 12, // 12: proto.Server.CreateLabel:output_type -> proto.CreateLabelResponse + 7, // [7:13] is the sub-list for method output_type + 1, // [1:7] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_server_proto_init() } +func file_server_proto_init() { + if File_server_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_server_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetInfoRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetInfoResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateUserRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateUserResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RevokeUserRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RevokeUserResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateAddressRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateAddressResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RemoveAddressRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RemoveAddressResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateLabelRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_server_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CreateLabelResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_server_proto_rawDesc, + NumEnums: 1, + NumMessages: 12, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_server_proto_goTypes, + DependencyIndexes: file_server_proto_depIdxs, + EnumInfos: file_server_proto_enumTypes, + MessageInfos: file_server_proto_msgTypes, + }.Build() + File_server_proto = out.File + file_server_proto_rawDesc = nil + file_server_proto_goTypes = nil + file_server_proto_depIdxs = nil +} diff --git a/server/proto/server.proto b/server/proto/server.proto new file mode 100644 index 0000000..1509894 --- /dev/null +++ b/server/proto/server.proto @@ -0,0 +1,84 @@ +syntax = "proto3"; + +option go_package = "github.com/ProtonMail/go-proton-api/server/proto"; + +package proto; + +//********************************************************************************************************************** +// Service Declaration +//********************************************************************************************************************** +service Server { + rpc GetInfo (GetInfoRequest) returns (GetInfoResponse); + + rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); + + rpc RevokeUser(RevokeUserRequest) returns (RevokeUserResponse); + + rpc CreateAddress(CreateAddressRequest) returns (CreateAddressResponse); + + rpc RemoveAddress(RemoveAddressRequest) returns (RemoveAddressResponse); + + rpc CreateLabel(CreateLabelRequest) returns (CreateLabelResponse); +} + +//********************************************************************************************************************** + +message GetInfoRequest { +} + +message GetInfoResponse { + string hostURL = 1; + string proxyURL = 2; +} + +message CreateUserRequest { + string username = 1; + string email = 2; + bytes password = 3; +} + +message CreateUserResponse { + string userID = 1; + string addrID = 2; +} + +message RevokeUserRequest { + string userID = 1; +} + +message RevokeUserResponse { +} + +message CreateAddressRequest { + string userID = 1; + string email = 2; + bytes password = 3; +} + +message CreateAddressResponse { + string addrID = 1; +} + +message RemoveAddressRequest { + string userID = 1; + string addrID = 2; +} + +message RemoveAddressResponse { +} + +enum LabelType { + FOLDER = 0; + LABEL = 1; +} + +message CreateLabelRequest { + string userID = 1; + string name = 2; + string parentID = 3; + LabelType type = 4; +} + +message CreateLabelResponse { + string labelID = 1; +} diff --git a/server/proto/server_grpc.pb.go b/server/proto/server_grpc.pb.go new file mode 100644 index 0000000..f5dee5b --- /dev/null +++ b/server/proto/server_grpc.pb.go @@ -0,0 +1,286 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.21.7 +// source: server.proto + +package proto + +import ( + context "context" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// ServerClient is the client API for Server service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ServerClient interface { + GetInfo(ctx context.Context, in *GetInfoRequest, opts ...grpc.CallOption) (*GetInfoResponse, error) + CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) + RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*RevokeUserResponse, error) + CreateAddress(ctx context.Context, in *CreateAddressRequest, opts ...grpc.CallOption) (*CreateAddressResponse, error) + RemoveAddress(ctx context.Context, in *RemoveAddressRequest, opts ...grpc.CallOption) (*RemoveAddressResponse, error) + CreateLabel(ctx context.Context, in *CreateLabelRequest, opts ...grpc.CallOption) (*CreateLabelResponse, error) +} + +type serverClient struct { + cc grpc.ClientConnInterface +} + +func NewServerClient(cc grpc.ClientConnInterface) ServerClient { + return &serverClient{cc} +} + +func (c *serverClient) GetInfo(ctx context.Context, in *GetInfoRequest, opts ...grpc.CallOption) (*GetInfoResponse, error) { + out := new(GetInfoResponse) + err := c.cc.Invoke(ctx, "/proto.Server/GetInfo", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *serverClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) { + out := new(CreateUserResponse) + err := c.cc.Invoke(ctx, "/proto.Server/CreateUser", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *serverClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*RevokeUserResponse, error) { + out := new(RevokeUserResponse) + err := c.cc.Invoke(ctx, "/proto.Server/RevokeUser", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *serverClient) CreateAddress(ctx context.Context, in *CreateAddressRequest, opts ...grpc.CallOption) (*CreateAddressResponse, error) { + out := new(CreateAddressResponse) + err := c.cc.Invoke(ctx, "/proto.Server/CreateAddress", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *serverClient) RemoveAddress(ctx context.Context, in *RemoveAddressRequest, opts ...grpc.CallOption) (*RemoveAddressResponse, error) { + out := new(RemoveAddressResponse) + err := c.cc.Invoke(ctx, "/proto.Server/RemoveAddress", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *serverClient) CreateLabel(ctx context.Context, in *CreateLabelRequest, opts ...grpc.CallOption) (*CreateLabelResponse, error) { + out := new(CreateLabelResponse) + err := c.cc.Invoke(ctx, "/proto.Server/CreateLabel", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ServerServer is the server API for Server service. +// All implementations must embed UnimplementedServerServer +// for forward compatibility +type ServerServer interface { + GetInfo(context.Context, *GetInfoRequest) (*GetInfoResponse, error) + CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) + RevokeUser(context.Context, *RevokeUserRequest) (*RevokeUserResponse, error) + CreateAddress(context.Context, *CreateAddressRequest) (*CreateAddressResponse, error) + RemoveAddress(context.Context, *RemoveAddressRequest) (*RemoveAddressResponse, error) + CreateLabel(context.Context, *CreateLabelRequest) (*CreateLabelResponse, error) + mustEmbedUnimplementedServerServer() +} + +// UnimplementedServerServer must be embedded to have forward compatible implementations. +type UnimplementedServerServer struct { +} + +func (UnimplementedServerServer) GetInfo(context.Context, *GetInfoRequest) (*GetInfoResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetInfo not implemented") +} +func (UnimplementedServerServer) CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateUser not implemented") +} +func (UnimplementedServerServer) RevokeUser(context.Context, *RevokeUserRequest) (*RevokeUserResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RevokeUser not implemented") +} +func (UnimplementedServerServer) CreateAddress(context.Context, *CreateAddressRequest) (*CreateAddressResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateAddress not implemented") +} +func (UnimplementedServerServer) RemoveAddress(context.Context, *RemoveAddressRequest) (*RemoveAddressResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method RemoveAddress not implemented") +} +func (UnimplementedServerServer) CreateLabel(context.Context, *CreateLabelRequest) (*CreateLabelResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CreateLabel not implemented") +} +func (UnimplementedServerServer) mustEmbedUnimplementedServerServer() {} + +// UnsafeServerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ServerServer will +// result in compilation errors. +type UnsafeServerServer interface { + mustEmbedUnimplementedServerServer() +} + +func RegisterServerServer(s grpc.ServiceRegistrar, srv ServerServer) { + s.RegisterService(&Server_ServiceDesc, srv) +} + +func _Server_GetInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetInfoRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServerServer).GetInfo(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.Server/GetInfo", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServerServer).GetInfo(ctx, req.(*GetInfoRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Server_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServerServer).CreateUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.Server/CreateUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServerServer).CreateUser(ctx, req.(*CreateUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Server_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServerServer).RevokeUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.Server/RevokeUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServerServer).RevokeUser(ctx, req.(*RevokeUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Server_CreateAddress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateAddressRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServerServer).CreateAddress(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.Server/CreateAddress", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServerServer).CreateAddress(ctx, req.(*CreateAddressRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Server_RemoveAddress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RemoveAddressRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServerServer).RemoveAddress(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.Server/RemoveAddress", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServerServer).RemoveAddress(ctx, req.(*RemoveAddressRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Server_CreateLabel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CreateLabelRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ServerServer).CreateLabel(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/proto.Server/CreateLabel", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ServerServer).CreateLabel(ctx, req.(*CreateLabelRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Server_ServiceDesc is the grpc.ServiceDesc for Server service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Server_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "proto.Server", + HandlerType: (*ServerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetInfo", + Handler: _Server_GetInfo_Handler, + }, + { + MethodName: "CreateUser", + Handler: _Server_CreateUser_Handler, + }, + { + MethodName: "RevokeUser", + Handler: _Server_RevokeUser_Handler, + }, + { + MethodName: "CreateAddress", + Handler: _Server_CreateAddress_Handler, + }, + { + MethodName: "RemoveAddress", + Handler: _Server_RemoveAddress_Handler, + }, + { + MethodName: "CreateLabel", + Handler: _Server_CreateLabel_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "server.proto", +} diff --git a/server/proxy.go b/server/proxy.go new file mode 100644 index 0000000..89e8ed1 --- /dev/null +++ b/server/proxy.go @@ -0,0 +1,238 @@ +package server + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "io" + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "github.com/ProtonMail/go-proton-api" + "github.com/gin-gonic/gin" +) + +func newProxy(proxyOrigin, base, path string) http.HandlerFunc { + origin, err := url.Parse(proxyOrigin) + if err != nil { + panic(err) + } + + return (&httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = origin.Scheme + req.URL.Host = origin.Host + req.URL.Path = origin.Path + strings.TrimPrefix(path, base) + req.Host = origin.Host + }, + + Transport: proton.InsecureTransport(), + }).ServeHTTP +} + +func (s *Server) handleProxy(base string) gin.HandlerFunc { + return func(c *gin.Context) { + proxy := newProxyServer(s.proxyOrigin, base) + + proxy.handle("/", s.handleProxyAll) + + if s.authCacher != nil { + proxy.handle("/core/v4/auth", s.handleProxyAuth) + proxy.handle("/core/v4/auth/info", s.handleProxyAuthInfo) + } + + proxy.ServeHTTP(c.Writer, c.Request) + } +} + +func (s *Server) handleProxyAll(proxier func(string) HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if _, err := proxier(r.URL.Path)(w, r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } +} + +func (s *Server) handleProxyAuth(proxier func(string) HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + s.handleProxyAuthPost(w, r, proxier(r.URL.Path)) + + case http.MethodDelete: + s.handleProxyAuthDelete(w, r, proxier(r.URL.Path)) + } + } +} + +func (s *Server) handleProxyAuthPost(w http.ResponseWriter, r *http.Request, proxier HandlerFunc) { + req, err := readFromBody[proton.AuthReq](r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if info, ok := s.authCacher.GetAuth(req.Username); ok { + if err := writeBody(w, info); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } else { + b, err := proxier(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + res, err := readFrom[proton.Auth](b) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.authCacher.SetAuth(req.Username, res) + } +} + +func (s *Server) handleProxyAuthDelete(w http.ResponseWriter, r *http.Request, proxier HandlerFunc) { + // When caching, we don't need to do anything here. +} + +func (s *Server) handleProxyAuthInfo(proxier func(string) HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + req, err := readFromBody[proton.AuthInfoReq](r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if info, ok := s.authCacher.GetAuthInfo(req.Username); ok { + if err := writeBody(w, info); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } else { + b, err := proxier(r.URL.Path)(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + res, err := readFrom[proton.AuthInfo](b) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + s.authCacher.SetAuthInfo(req.Username, res) + } + } +} + +type HandlerFunc func(http.ResponseWriter, *http.Request) ([]byte, error) + +type proxyServer struct { + mux *http.ServeMux + + origin, base string +} + +func newProxyServer(origin, base string) *proxyServer { + return &proxyServer{ + mux: http.NewServeMux(), + origin: origin, + base: base, + } +} + +func (s *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} + +func (s *proxyServer) handle(path string, h func(func(string) HandlerFunc) http.HandlerFunc) { + s.mux.Handle(s.base+path, h(func(path string) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) ([]byte, error) { + buf := new(bytes.Buffer) + + // Call the proxy, capturing whatever data it writes. + newProxy(s.origin, s.base, path)(&writerWrapper{w, buf}, r) + + // If there is a gzip header entry, decode it. + if strings.Contains(w.Header().Get("Content-Encoding"), "gzip") { + return gzipDecode(buf.Bytes()) + } + + // Otherwise, return the original written data. + return buf.Bytes(), nil + } + })) +} + +type writerWrapper struct { + http.ResponseWriter + + buf *bytes.Buffer +} + +func (w *writerWrapper) Write(b []byte) (int, error) { + if _, err := w.buf.Write(b); err != nil { + return 0, err + } + + return w.ResponseWriter.Write(b) +} + +func readFrom[T any](b []byte) (T, error) { + var v T + + if err := json.Unmarshal(b, &v); err != nil { + return *new(T), err + } + + return v, nil +} + +func readFromBody[T any](r *http.Request) (T, error) { + b, err := io.ReadAll(r.Body) + if err != nil { + return *new(T), err + } + defer r.Body.Close() + + v, err := readFrom[T](b) + if err != nil { + return *new(T), err + } + + r.Body = io.NopCloser(bytes.NewReader(b)) + + return v, nil +} + +func writeBody[T any](w http.ResponseWriter, v T) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + w.Header().Set("Content-Type", "application/json") + + if _, err := w.Write(b); err != nil { + return err + } + + return nil +} + +func gzipDecode(b []byte) ([]byte, error) { + r, err := gzip.NewReader(bytes.NewReader(b)) + if err != nil { + return nil, err + } + defer r.Close() + + return io.ReadAll(r) +} diff --git a/server/reports.go b/server/reports.go new file mode 100644 index 0000000..ab47c4a --- /dev/null +++ b/server/reports.go @@ -0,0 +1,16 @@ +package server + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func (s *Server) handlePostReportBug() gin.HandlerFunc { + return func(c *gin.Context) { + if _, err := c.MultipartForm(); err != nil { + _ = c.AbortWithError(http.StatusBadRequest, err) + return + } + } +} diff --git a/server/router.go b/server/router.go new file mode 100644 index 0000000..20d9006 --- /dev/null +++ b/server/router.go @@ -0,0 +1,284 @@ +package server + +import ( + "bytes" + "errors" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/go-proton-api" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +func initRouter(s *Server) { + s.r.Use( + s.requireValidAppVersion(), + s.setSessionCookie(), + ) + + if core := s.r.Group("/core/v4"); core != nil { + // These routes are not protected by authentication. + if auth := core.Group("/auth"); auth != nil { + auth.POST("", s.handlePostAuth()) + auth.POST("/info", s.handlePostAuthInfo()) + auth.POST("/refresh", s.handlePostAuthRefresh()) + } + + // Reporting a bug is also possible without authentication. + if reports := core.Group("/reports"); reports != nil { + reports.POST("/bug", s.handlePostReportBug()) + } + + // These routes require auth. + if core := core.Group("", s.requireAuth()); core != nil { + if auth := core.Group("/auth"); auth != nil { + auth.DELETE("", s.handleDeleteAuth()) + } + + if users := core.Group("/users"); users != nil { + users.GET("", s.handleGetUsers()) + } + + if addresses := core.Group("/addresses"); addresses != nil { + addresses.GET("", s.handleGetAddresses()) + addresses.GET("/:addressID", s.handleGetAddress()) + addresses.PUT("/order", s.handlePutAddressesOrder()) + } + + if labels := core.Group("/labels"); labels != nil { + labels.GET("", s.handleGetMailLabels()) + labels.POST("", s.handlePostMailLabels()) + labels.PUT("/:labelID", s.handlePutMailLabel()) + labels.DELETE("/:labelID", s.handleDeleteMailLabel()) + } + + if keys := core.Group("/keys"); keys != nil { + keys.GET("", s.handleGetKeys()) + keys.GET("/salts", s.handleGetKeySalts()) + } + + if events := core.Group("/events"); events != nil { + events.GET("/:eventID", s.handleGetEvents()) + events.GET("/latest", s.handleGetEventsLatest()) + } + } + } + + // All mail routes need authentication. + if mail := s.r.Group("/mail/v4", s.requireAuth()); mail != nil { + if settings := mail.Group("/settings"); settings != nil { + settings.GET("", s.handleGetMailSettings()) + } + + if messages := mail.Group("/messages"); messages != nil { + messages.GET("", s.handleGetMailMessages()) + messages.POST("", s.handlePostMailMessages()) + messages.GET("/ids", s.handleGetMailMessageIDs()) + messages.GET("/:messageID", s.handleGetMailMessage()) + messages.POST("/:messageID", s.handlePostMailMessage()) + messages.PUT("/read", s.handlePutMailMessagesRead()) + messages.PUT("/unread", s.handlePutMailMessagesUnread()) + messages.PUT("/label", s.handlePutMailMessagesLabel()) + messages.PUT("/unlabel", s.handlePutMailMessagesUnlabel()) + messages.POST("/import", s.handlePutMailMessagesImport()) + messages.PUT("/delete", s.handleDeleteMailMessages()) + } + + if attachments := mail.Group("/attachments"); attachments != nil { + attachments.POST("", s.handlePostMailAttachments()) + attachments.GET(":attachID", s.handleGetMailAttachment()) + } + } + + // All contacts routes need authentication. + if contacts := s.r.Group("/contacts/v4", s.requireAuth()); contacts != nil { + contacts.GET("/emails", s.handleGetContactsEmails()) + } + + // All auth routes need authentication. + if auth := s.r.Group("/auth/v4", s.requireAuth()); auth != nil { + auth.GET("/sessions", s.handleGetAuthSessions()) + auth.DELETE("/sessions", s.handleDeleteAuthSessions()) + auth.DELETE("/sessions/:authUID", s.handleDeleteAuthSession()) + } + + // Test routes don't need authentication. + if tests := s.r.Group("/tests"); tests != nil { + tests.GET("/ping", s.handleGetPing()) + } + + // Proxy any calls to the upstream server. + if proxy := s.r.Group("/proxy"); proxy != nil { + proxy.Any("/*path", s.handleProxy(proxy.BasePath())) + } +} + +func (s *Server) requireValidAppVersion() gin.HandlerFunc { + return func(c *gin.Context) { + appVersion := c.Request.Header.Get("x-pm-appversion") + + if appVersion == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, proton.Error{ + Code: proton.AppVersionMissingCode, + Message: "Missing x-pm-appversion header", + }) + } else if ok := s.validateAppVersion(appVersion); !ok { + c.AbortWithStatusJSON(http.StatusBadRequest, proton.Error{ + Code: proton.AppVersionBadCode, + Message: "This version of the app is no longer supported, please update to continue using the app", + }) + } + } +} + +func (s *Server) setSessionCookie() gin.HandlerFunc { + return func(c *gin.Context) { + url, err := url.Parse(s.s.URL) + if err != nil { + panic(err) + } + + host, _, err := net.SplitHostPort(url.Host) + if err != nil { + panic(err) + } + + if cookie, err := c.Request.Cookie("Session-Id"); errors.Is(err, http.ErrNoCookie) { + c.SetCookie("Session-Id", uuid.NewString(), int(90*24*time.Hour.Seconds()), "/", host, true, true) + } else { + c.SetCookie("Session-Id", cookie.Value, int(90*24*time.Hour.Seconds()), "/", host, true, true) + } + } +} + +func (s *Server) logCalls() gin.HandlerFunc { + return func(c *gin.Context) { + req, err := io.ReadAll(c.Request.Body) + if err != nil { + panic(err) + } else { + c.Request.Body = io.NopCloser(bytes.NewReader(req)) + } + + res, err := newBodyWriter(c.Writer) + if err != nil { + panic(err) + } else { + c.Writer = res + } + + c.Next() + + s.callWatchersLock.RLock() + defer s.callWatchersLock.RUnlock() + + for _, call := range s.callWatchers { + if call.isWatching(c.Request.URL.Path) { + call.publish(Call{ + URL: c.Request.URL, + Method: c.Request.Method, + Status: c.Writer.Status(), + + RequestHeader: c.Request.Header, + RequestBody: req, + + ResponseHeader: c.Writer.Header(), + ResponseBody: res.bytes(), + }) + } + } + } +} + +func (s *Server) handleOffline() gin.HandlerFunc { + return func(c *gin.Context) { + if s.offline { + c.AbortWithStatus(http.StatusServiceUnavailable) + return + } + } +} + +func (s *Server) requireAuth() gin.HandlerFunc { + return func(c *gin.Context) { + authUID := c.Request.Header.Get("x-pm-uid") + if authUID == "" { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + auth := c.Request.Header.Get("Authorization") + if auth == "" { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + userID, err := s.b.VerifyAuth(authUID, strings.Split(auth, " ")[1]) + if err != nil { + c.AbortWithStatus(http.StatusUnauthorized) + return + } + + c.Set("UserID", userID) + + c.Set("AuthUID", authUID) + } +} + +func (s *Server) validateAppVersion(appVersion string) bool { + if s.minAppVersion == nil { + return true + } + + split := strings.Split(appVersion, "_") + + if len(split) != 2 { + return false + } + + version, err := semver.NewVersion(split[1]) + if err != nil { + return false + } + + if version.LessThan(s.minAppVersion) { + return false + } + + return true +} + +type bodyWriter struct { + gin.ResponseWriter + buf *bytes.Buffer +} + +func newBodyWriter(w gin.ResponseWriter) (*bodyWriter, error) { + if w == nil { + return nil, errors.New("response writer is nil") + } + + return &bodyWriter{ + ResponseWriter: w, + + buf: &bytes.Buffer{}, + }, nil +} + +func (w bodyWriter) Write(b []byte) (int, error) { + if n, err := w.buf.Write(b); err != nil { + return n, err + } + + return w.ResponseWriter.Write(b) +} + +func (w bodyWriter) bytes() []byte { + return w.buf.Bytes() +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..1e0e119 --- /dev/null +++ b/server/server.go @@ -0,0 +1,188 @@ +package server + +import ( + "net/http/httptest" + "sync" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server/backend" + "github.com/bradenaw/juniper/xslices" + "github.com/gin-gonic/gin" +) + +type AuthCacher interface { + GetAuthInfo(username string) (proton.AuthInfo, bool) + SetAuthInfo(username string, info proton.AuthInfo) + GetAuth(username string) (proton.Auth, bool) + SetAuth(username string, auth proton.Auth) +} + +type Server struct { + // r is the gin router. + r *gin.Engine + + // s is the underlying server. + s *httptest.Server + + // b is the server backend, which manages accounts, messages, attachments, etc. + b *backend.Backend + + // callWatchers records callWatchers received by the server. + callWatchers []callWatcher + callWatchersLock sync.RWMutex + + // MinAppVersion is the minimum app version that the server will accept. + minAppVersion *semver.Version + + // proxyOrigin is the URL of the origin server when the server is a proxy. + proxyOrigin string + + // authCacher can optionally be set to cache proxied auth calls. + authCacher AuthCacher + + // offline is whether to pretend the server is offline and return 5xx errors. + offline bool +} + +func New(opts ...Option) *Server { + builder := newServerBuilder() + + for _, opt := range opts { + opt.config(builder) + } + + return builder.build() +} + +func (s *Server) GetHostURL() string { + return s.s.URL +} + +// GetProxyURL returns the API root to make calls to which should be proxied. +func (s *Server) GetProxyURL() string { + return s.s.URL + "/proxy" +} + +func (s *Server) AddCallWatcher(fn func(Call), paths ...string) { + s.callWatchersLock.Lock() + defer s.callWatchersLock.Unlock() + + s.callWatchers = append(s.callWatchers, newCallWatcher(fn, paths...)) +} + +func (s *Server) CreateUser(username, email string, password []byte) (string, string, error) { + userID, err := s.b.CreateUser(username, password) + if err != nil { + return "", "", err + } + + addrID, err := s.b.CreateAddress(userID, email, password) + if err != nil { + return "", "", err + } + + return userID, addrID, nil +} + +func (s *Server) RemoveUser(userID string) error { + return s.b.RemoveUser(userID) +} + +func (s *Server) RefreshUser(userID string, refresh proton.RefreshFlag) error { + return s.b.RefreshUser(userID, refresh) +} + +func (s *Server) GetUserKeyIDs(userID string) ([]string, error) { + user, err := s.b.GetUser(userID) + if err != nil { + return nil, err + } + + return xslices.Map(user.Keys, func(key proton.Key) string { + return key.ID + }), nil +} + +func (s *Server) CreateUserKey(userID string, password []byte) error { + return s.b.CreateUserKey(userID, password) +} + +func (s *Server) RemoveUserKey(userID, keyID string) error { + return s.b.RemoveUserKey(userID, keyID) +} + +func (s *Server) CreateAddress(userID, email string, password []byte) (string, error) { + return s.b.CreateAddress(userID, email, password) +} + +func (s *Server) RemoveAddress(userID, addrID string) error { + return s.b.RemoveAddress(userID, addrID) +} + +func (s *Server) CreateAddressKey(userID, addrID string, password []byte) error { + return s.b.CreateAddressKey(userID, addrID, password) +} + +func (s *Server) RemoveAddressKey(userID, addrID, keyID string) error { + return s.b.RemoveAddressKey(userID, addrID, keyID) +} + +func (s *Server) CreateLabel(userID, name, parentID string, labelType proton.LabelType) (string, error) { + label, err := s.b.CreateLabel(userID, name, parentID, labelType) + if err != nil { + return "", err + } + + return label.ID, nil +} + +func (s *Server) GetLabels(userID string) ([]proton.Label, error) { + return s.b.GetLabels(userID) +} + +func (s *Server) LabelMessage(userID, msgID, labelID string) error { + return s.b.LabelMessages(userID, labelID, msgID) +} + +func (s *Server) UnlabelMessage(userID, msgID, labelID string) error { + return s.b.UnlabelMessages(userID, labelID, msgID) +} + +func (s *Server) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) error { + _, err := s.b.UpdateDraft(userID, draftID, changes) + + return err +} + +func (s *Server) SetAuthLife(authLife time.Duration) { + s.b.SetAuthLife(authLife) +} + +func (s *Server) SetMinAppVersion(minAppVersion *semver.Version) { + s.minAppVersion = minAppVersion +} + +func (s *Server) SetOffline(offline bool) { + s.offline = offline +} + +func (s *Server) RevokeUser(userID string) error { + sessions, err := s.b.GetSessions(userID) + if err != nil { + return err + } + + for _, session := range sessions { + if err := s.b.DeleteSession(userID, session.UID); err != nil { + return err + } + } + + return nil +} + +func (s *Server) Close() { + s.s.Close() +} diff --git a/server/server_builder.go b/server/server_builder.go new file mode 100644 index 0000000..8beaed5 --- /dev/null +++ b/server/server_builder.go @@ -0,0 +1,127 @@ +package server + +import ( + "io" + "net/http/httptest" + "os" + "time" + + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/go-proton-api/server/backend" + "github.com/gin-gonic/gin" +) + +type serverBuilder struct { + withTLS bool + logger io.Writer + origin string + cacher AuthCacher +} + +func newServerBuilder() *serverBuilder { + var logger io.Writer + + if os.Getenv("GO_PROTON_API_SERVER_LOGGER_ENABLED") != "" { + logger = gin.DefaultWriter + } else { + logger = io.Discard + } + + return &serverBuilder{ + withTLS: true, + logger: logger, + origin: proton.DefaultHostURL, + } +} + +func (builder *serverBuilder) build() *Server { + gin.SetMode(gin.ReleaseMode) + + s := &Server{ + r: gin.New(), + b: backend.New(time.Hour), + + proxyOrigin: builder.origin, + authCacher: builder.cacher, + } + + if builder.withTLS { + s.s = httptest.NewTLSServer(s.r) + } else { + s.s = httptest.NewServer(s.r) + } + + s.r.Use( + gin.LoggerWithConfig(gin.LoggerConfig{Output: builder.logger}), + gin.Recovery(), + s.logCalls(), + s.handleOffline(), + ) + + initRouter(s) + + return s +} + +// Option represents a type that can be used to configure the server. +type Option interface { + config(*serverBuilder) +} + +// WithTLS controls whether the server should serve over TLS. +func WithTLS(tls bool) Option { + return &withTLS{ + withTLS: tls, + } +} + +type withTLS struct { + withTLS bool +} + +func (opt withTLS) config(builder *serverBuilder) { + builder.withTLS = opt.withTLS +} + +// WithLogger controls where Gin logs to. +func WithLogger(logger io.Writer) Option { + return &withLogger{ + logger: logger, + } +} + +type withLogger struct { + logger io.Writer +} + +func (opt withLogger) config(builder *serverBuilder) { + builder.logger = opt.logger +} + +func WithProxyOrigin(origin string) Option { + return &withProxyOrigin{ + origin: origin, + } +} + +type withProxyOrigin struct { + origin string +} + +func (opt withProxyOrigin) config(builder *serverBuilder) { + builder.origin = opt.origin +} + +func WithAuthCacher(cacher AuthCacher) Option { + return &withAuthCache{ + cacher: cacher, + } +} + +type withAuthCache struct { + cacher AuthCacher +} + +func (opt withAuthCache) config(builder *serverBuilder) { + builder.cacher = opt.cacher +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..984060e --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,1712 @@ +package server + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/mail" + "net/url" + "os" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Masterminds/semver/v3" + "github.com/ProtonMail/go-proton-api" + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/iterator" + "github.com/bradenaw/juniper/stream" + "github.com/bradenaw/juniper/xslices" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" +) + +func TestServer(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + user, err := c.GetUser(ctx) + require.NoError(t, err) + require.Equal(t, "user", user.Name) + require.Equal(t, "email@pm.me", user.Email) + + // Logout from the test API. + require.NoError(t, c.AuthDelete(ctx)) + + // Future requests should fail. + require.Error(t, c.AuthDelete(ctx)) + }) + }) +} + +func TestServerMulti(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + _, _, err := s.CreateUser("user", "email@pm.me", []byte("pass")) + require.NoError(t, err) + + // Create one client. + c1, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer c1.Close() + + // Create another client. + c2, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer c2.Close() + + // Both clients should be able to make requests. + must(c1.GetUser(ctx)) + must(c2.GetUser(ctx)) + + // Logout the first client; it should no longer be able to make requests. + require.NoError(t, c1.AuthDelete(ctx)) + require.Panics(t, func() { must(c1.GetUser(ctx)) }) + + // The second client should still be able to make requests. + must(c2.GetUser(ctx)) + }) +} + +func TestServer_Ping(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, _ *proton.Manager) { + ctl := proton.NewNetCtl() + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + ) + + var status proton.Status + + m.AddStatusObserver(func(s proton.Status) { + status = s + }) + + // When the network goes down, ping should fail. + ctl.Disable() + require.Error(t, m.Ping(ctx)) + require.Equal(t, proton.StatusDown, status) + + // When the network goes up, ping should succeed. + ctl.Enable() + require.NoError(t, m.Ping(ctx)) + require.Equal(t, proton.StatusUp, status) + + // When the API is down, ping should still succeed if the API is at least reachable. + s.SetOffline(true) + require.NoError(t, m.Ping(ctx)) + require.Equal(t, proton.StatusUp, status) + + // When the API is down, ping should fail if the API cannot be reached. + ctl.Disable() + require.Error(t, m.Ping(ctx)) + require.Equal(t, proton.StatusDown, status) + + // When the network goes up, ping should succeed, even if the API is down. + ctl.Enable() + require.NoError(t, m.Ping(ctx)) + require.Equal(t, proton.StatusUp, status) + + // When the API comes back alive, ping should succeed. + s.SetOffline(false) + require.NoError(t, m.Ping(ctx)) + require.Equal(t, proton.StatusUp, status) + }) +} + +func TestServer_Bool(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1, func([]string) { + metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) + require.NoError(t, err) + + // By default the message is unread. + require.True(t, bool(must(c.GetMessage(ctx, metadata[0].ID)).Unread)) + + // Mark the message as read. + require.NoError(t, c.MarkMessagesRead(ctx, metadata[0].ID)) + + // Now the message is read. + require.False(t, bool(must(c.GetMessage(ctx, metadata[0].ID)).Unread)) + }) + }) + }) +} + +func TestServer_Messages(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { + // Get the messages. + metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) + require.NoError(t, err) + + // The messages should be the ones we created. + require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { + return metadata.ID + })) + + // The messages should be in All Mail and should be unread. + for _, message := range metadata { + require.True(t, bool(message.Unread)) + require.Equal(t, []string{proton.AllMailLabel}, message.LabelIDs) + } + + // Mark the first three messages as read and put them in archive. + require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[0], messageIDs[1], messageIDs[2])) + require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[0], messageIDs[1], messageIDs[2]}, proton.ArchiveLabel)) + + // They should now be read. + require.False(t, bool(must(c.GetMessage(ctx, messageIDs[0])).Unread)) + require.False(t, bool(must(c.GetMessage(ctx, messageIDs[1])).Unread)) + require.False(t, bool(must(c.GetMessage(ctx, messageIDs[2])).Unread)) + + // They should now be in archive. + require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, messageIDs[0])).LabelIDs) + require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, messageIDs[1])).LabelIDs) + require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, messageIDs[2])).LabelIDs) + + // Put them in inbox. + require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[0], messageIDs[1], messageIDs[2]}, proton.ArchiveLabel)) + }) + }) + }) +} + +func TestServer_MessageFilter(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { + // Get the messages. + metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) + require.NoError(t, err) + + // The messages should be the ones we created. + require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { + return metadata.ID + })) + + // Get metadata for just the first three messages. + partial, err := c.GetMessageMetadata(ctx, proton.MessageFilter{ + ID: []string{ + metadata[0].ID, + metadata[1].ID, + metadata[2].ID, + }, + }) + require.NoError(t, err) + + // The messages should be just the first three. + require.Equal(t, metadata[:3], partial) + }) + }) + }) +} + +func TestServer_MessageIDs(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 10000, func(wantMessageIDs []string) { + allMessageIDs, err := c.GetMessageIDs(ctx, "") + require.NoError(t, err) + require.ElementsMatch(t, wantMessageIDs, allMessageIDs) + + halfMessageIDs, err := c.GetMessageIDs(ctx, allMessageIDs[len(allMessageIDs)/2]) + require.NoError(t, err) + require.ElementsMatch(t, allMessageIDs[len(allMessageIDs)/2+1:], halfMessageIDs) + }) + }) + }) +} + +func TestServer_MessagesDelete(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { + // Get the messages. + metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) + require.NoError(t, err) + + // The messages should be the ones we created. + require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { + return metadata.ID + })) + + // Delete half the messages. + require.NoError(t, c.DeleteMessage(ctx, messageIDs[0:500]...)) + + // Get the remaining messages. + remaining, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) + require.NoError(t, err) + + // The remaining messages should be the ones we didn't delete. + require.ElementsMatch(t, messageIDs[500:], xslices.Map(remaining, func(metadata proton.MessageMetadata) string { + return metadata.ID + })) + }) + }) + }) +} + +func TestServer_MessagesDeleteAfterUpdate(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { + // Get the initial event ID. + eventID, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + + // Put half the messages in archive. + require.NoError(t, c.LabelMessages(ctx, messageIDs[0:500], proton.ArchiveLabel)) + + // Delete half the messages. + require.NoError(t, c.DeleteMessage(ctx, messageIDs[0:500]...)) + + // Get the event reflecting this change. + event, err := c.GetEvent(ctx, eventID) + require.NoError(t, err) + + // The event should have the correct number of message events. + require.Len(t, event.Messages, 500) + + // All the events should be delete events. + for _, message := range event.Messages { + require.Equal(t, proton.EventDelete, message.Action) + } + }) + }) + }) +} + +func TestServer_Events(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 3, func(messageIDs []string) { + // Get the latest event ID to stream from. + fromEventID, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + + // Begin collecting events. + eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) + + // Mark a message as read. + require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[0])) + + // The message should eventually be read. + require.Eventually(t, func() bool { + event := <-eventCh + + if len(event.Messages) != 1 { + return false + } + + if event.Messages[0].ID != messageIDs[0] { + return false + } + + return !bool(event.Messages[0].Message.Unread) + }, 5*time.Second, time.Millisecond*100) + + // Add another message to archive. + require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[1]}, proton.ArchiveLabel)) + + // The message should eventually be in archive and all mail. + require.Eventually(t, func() bool { + event := <-eventCh + + if len(event.Messages) != 1 { + return false + } + + if event.Messages[0].ID != messageIDs[1] { + return false + } + + return elementsMatch([]string{proton.ArchiveLabel, proton.AllMailLabel}, event.Messages[0].Message.LabelIDs) + }, 5*time.Second, time.Millisecond*100) + + // Perform a sequence of actions on the same message. + require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[2]}, proton.TrashLabel)) + require.NoError(t, c.UnlabelMessages(ctx, []string{messageIDs[2]}, proton.TrashLabel)) + require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[2])) + require.NoError(t, c.MarkMessagesUnread(ctx, messageIDs[2])) + + // The final state of the message should be correct. + require.Eventually(t, func() bool { + event := <-eventCh + + if len(event.Messages) != 1 { + return false + } + + if event.Messages[0].ID != messageIDs[2] { + return false + } + + return bool(event.Messages[0].Message.Unread) && elementsMatch([]string{proton.AllMailLabel}, event.Messages[0].Message.LabelIDs) + }, 5*time.Second, time.Millisecond*100) + + // No more events should be sent. + select { + case <-eventCh: + t.Fatal("unexpected event") + + default: + // .... + } + }) + }) + }) +} + +func TestServer_Events_Multi(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + for i := 0; i < 10; i++ { + withUser(ctx, t, s, m, fmt.Sprintf("user%v", i), fmt.Sprintf("email%v@pm.me", i), "pass", func(c *proton.Client) { + latest, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + + // Fetching latest again should return the same event ID. + latestAgain, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + require.Equal(t, latest, latestAgain) + + event, err := c.GetEvent(ctx, latest) + require.NoError(t, err) + + // The event should be empty. + require.Equal(t, proton.Event{EventID: event.EventID}, event) + + // After fetching an empty event, its ID should still be the latest. + eventAgain, err := c.GetEvent(ctx, event.EventID) + require.NoError(t, err) + require.Equal(t, eventAgain.EventID, event.EventID) + }) + } + }) +} + +func TestServer_Events_Refresh(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + user, err := c.GetUser(ctx) + require.NoError(t, err) + + // Get the latest event ID to stream from. + fromEventID, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + + // Refresh the user's mail. + require.NoError(t, s.RefreshUser(user.ID, proton.RefreshMail)) + + // Begin collecting events. + eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) + + // The user should eventually be refreshed. + require.Eventually(t, func() bool { + return (<-eventCh).Refresh&proton.RefreshMail != 0 + }, 5*time.Second, time.Millisecond*100) + }) + }) +} + +func TestServer_RevokeUser(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + user, err := c.GetUser(ctx) + require.NoError(t, err) + require.Equal(t, "user", user.Name) + require.Equal(t, "email@pm.me", user.Email) + + // Revoke the user's auth. + require.NoError(t, s.RevokeUser(user.ID)) + + // Future requests should fail. + require.Error(t, c.AuthDelete(ctx)) + }) + }) +} + +func TestServer_Calls(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + var calls []Call + + // Watch calls that are made. + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + // Get the user. + _, err := c.GetUser(ctx) + require.NoError(t, err) + + // Logout the user. + require.NoError(t, c.AuthDelete(ctx)) + + // The user call should be correct. + userCall := calls[0] + require.Equal(t, "/core/v4/users", userCall.URL.Path) + require.Equal(t, "GET", userCall.Method) + require.Equal(t, http.StatusOK, userCall.Status) + + // The logout call should be correct. + logoutCall := calls[1] + require.Equal(t, "/core/v4/auth", logoutCall.URL.Path) + require.Equal(t, "DELETE", logoutCall.Method) + require.Equal(t, http.StatusOK, logoutCall.Status) + }) + }) +} + +func TestServer_Calls_Status(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + var calls []Call + + // Watch calls that are made. + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + // Make a bad call. + _, err := c.GetMessage(ctx, "no such message ID") + require.Error(t, err) + + // The user call should have error status. + require.Equal(t, http.StatusUnprocessableEntity, calls[0].Status) + }) + }) +} + +func TestServer_Calls_Request(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + var calls []Call + + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(*proton.Client) { + require.Equal( + t, + calls[0].RequestBody, + must(json.Marshal(proton.AuthInfoReq{Username: "user"})), + ) + }) + }) +} + +func TestServer_Calls_Response(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + var calls []Call + + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + salts, err := c.GetSalts(ctx) + require.NoError(t, err) + + require.Equal( + t, + calls[len(calls)-1].ResponseBody, + must(json.Marshal(struct{ KeySalts []proton.Salt }{salts})), + ) + }) + }) +} + +func TestServer_Calls_Cookies(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + var calls []Call + + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(*proton.Client) { + // The header in the first call's response should set the Session-Id cookie. + resHeader := (&http.Response{Header: calls[len(calls)-2].ResponseHeader}) + require.Len(t, resHeader.Cookies(), 1) + require.Equal(t, "Session-Id", resHeader.Cookies()[0].Name) + require.NotEmpty(t, resHeader.Cookies()[0].Value) + + // The cookie should be sent in the next call. + reqHeader := (&http.Request{Header: calls[len(calls)-1].RequestHeader}) + require.Len(t, reqHeader.Cookies(), 1) + require.Equal(t, "Session-Id", reqHeader.Cookies()[0].Name) + require.NotEmpty(t, reqHeader.Cookies()[0].Value) + + // The cookie should be the same. + require.Equal(t, resHeader.Cookies()[0].Value, reqHeader.Cookies()[0].Value) + }) + }) +} + +func TestServer_Calls_Manager(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + var calls []Call + + // Watch calls that are made. + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + // Make a non-user request. + require.NoError(t, m.ReportBug(ctx, proton.ReportBugReq{})) + + // The call should be correct. + reportCall := calls[0] + require.Equal(t, "/core/v4/reports/bug", reportCall.URL.Path) + require.Equal(t, "POST", reportCall.Method) + require.Equal(t, http.StatusOK, reportCall.Status) + }) +} + +func TestServer_CreateMessage(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + addresses, err := c.GetAddresses(ctx) + require.NoError(t, err) + + draft, err := c.CreateDraft(ctx, proton.CreateDraftReq{ + Message: proton.DraftTemplate{ + Subject: "My subject", + Sender: &mail.Address{Address: addresses[0].Email}, + ToList: []*mail.Address{{Address: "recipient@pm.me"}}, + }, + }) + require.NoError(t, err) + + require.Equal(t, addresses[0].ID, draft.AddressID) + require.Equal(t, "My subject", draft.Subject) + require.Equal(t, &mail.Address{Address: "email@pm.me"}, draft.Sender) + }) + }) +} + +func TestServer_UpdateDraft(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + addresses, err := c.GetAddresses(ctx) + require.NoError(t, err) + + draft, err := c.CreateDraft(ctx, proton.CreateDraftReq{ + Message: proton.DraftTemplate{ + Subject: "My subject", + Sender: &mail.Address{Address: addresses[0].Email}, + ToList: []*mail.Address{{Address: "recipient@pm.me"}}, + }, + }) + require.NoError(t, err) + + require.Equal(t, addresses[0].ID, draft.AddressID) + require.Equal(t, "My subject", draft.Subject) + require.Equal(t, &mail.Address{Address: "email@pm.me"}, draft.Sender) + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + fromEventID, err := c.GetLatestEventID(ctx) + require.NoError(t, err) + + eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) + + // Draft updated on server side. + _, err = s.b.UpdateDraft(user.ID, draft.ID, proton.DraftTemplate{ + Subject: "Edited subject", + ToList: []*mail.Address{{Address: "edited@pm.me"}}, + Body: "Edited body", + }) + require.NoError(t, err) + + var updated *proton.MessageMetadata + + require.Eventually(t, func() bool { + event := <-eventCh + + if len(event.Messages) != 1 { + return false + } + + if event.Messages[0].ID != draft.ID { + return false + } + + if event.Messages[0].Action != proton.EventUpdate { + return false + } + + updated = &event.Messages[0].Message + + return true + }, 5*time.Second, time.Millisecond*100) + + require.Equal(t, draft.ID, updated.ID) + require.Equal(t, "Edited subject", updated.Subject) + require.Equal(t, []*mail.Address{{Address: "edited@pm.me"}}, updated.ToList) + }) + }) +} + +func TestServer_SendMessage(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + addresses, err := c.GetAddresses(ctx) + require.NoError(t, err) + + draft, err := c.CreateDraft(ctx, proton.CreateDraftReq{ + Message: proton.DraftTemplate{ + Subject: "My subject", + Sender: &mail.Address{Address: addresses[0].Email}, + ToList: []*mail.Address{{Address: "recipient@pm.me"}}, + }, + }) + require.NoError(t, err) + + sent, err := c.SendDraft(ctx, draft.ID, proton.SendDraftReq{}) + require.NoError(t, err) + + require.Equal(t, draft.ID, sent.ID) + require.Equal(t, addresses[0].ID, sent.AddressID) + require.Equal(t, "My subject", sent.Subject) + require.Contains(t, sent.LabelIDs, proton.SentLabel) + }) + }) +} + +func TestServer_AuthDelete(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + require.NoError(t, c.AuthDelete(ctx)) + }) + }) +} + +func TestServer_ForceUpgrade(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := New() + defer s.Close() + + s.SetMinAppVersion(semver.MustParse("1.0.0")) + + if _, _, err := s.CreateUser("user", "email@pm.me", []byte("pass")); err != nil { + t.Fatal(err) + } + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithAppVersion("proton_0.9.0"), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + var called bool + + m.AddErrorHandler(proton.AppVersionBadCode, func() { + called = true + }) + + if _, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")); err == nil { + t.Fatal(err) + } + + require.True(t, called) +} + +func TestServer_Import(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass) + require.NoError(t, err) + + res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, 1) + require.NoError(t, err) + require.Len(t, res, 1) + require.Equal(t, proton.SuccessCode, res[0].Code) + + message, err := c.GetMessage(ctx, res[0].MessageID) + require.NoError(t, err) + + dec, err := message.Decrypt(addrKRs[message.AddressID]) + require.NoError(t, err) + require.NotEmpty(t, dec) + }) + }) +} + +func TestServer_Labels(t *testing.T) { + type add string + type rem string + + tests := []struct { + name string + flags proton.MessageFlag + actions []any + wantLabelIDs []string + wantError bool + }{ + { + name: "received flag, no actions", + flags: proton.MessageFlagReceived, + wantLabelIDs: []string{proton.AllMailLabel}, + }, + { + name: "sent flag, no actions", + flags: proton.MessageFlagSent, + wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add inbox", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.InboxLabel)}, + wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, + }, + { + name: "sent flag, add sent", + flags: proton.MessageFlagSent, + actions: []any{add(proton.SentLabel)}, + wantLabelIDs: []string{proton.SentLabel, proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add inbox then add archive", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.InboxLabel), add(proton.ArchiveLabel)}, + wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel}, + }, + { + name: "sent flag, add sent then add archive", + flags: proton.MessageFlagSent, + actions: []any{add(proton.SentLabel), add(proton.ArchiveLabel)}, + wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add inbox then remove inbox", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.InboxLabel), rem(proton.InboxLabel)}, + wantLabelIDs: []string{proton.AllMailLabel}, + }, + { + name: "sent flag, add sent then remove sent", + flags: proton.MessageFlagSent, + actions: []any{add(proton.SentLabel), rem(proton.SentLabel)}, + wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add inbox then remove archive", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.InboxLabel), rem(proton.ArchiveLabel)}, + wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, + }, + { + name: "sent flag, add sent then remove archive", + flags: proton.MessageFlagSent, + actions: []any{add(proton.SentLabel), rem(proton.ArchiveLabel)}, + wantLabelIDs: []string{proton.SentLabel, proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add inbox then remove inbox then add archive", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.InboxLabel), rem(proton.InboxLabel), add(proton.ArchiveLabel)}, + wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel}, + }, + { + name: "sent flag, add sent then remove sent then add archive", + flags: proton.MessageFlagSent, + actions: []any{add(proton.SentLabel), rem(proton.SentLabel), add(proton.ArchiveLabel)}, + wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add starred", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.StarredLabel)}, + wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel}, + }, + { + name: "sent flag, add starred", + flags: proton.MessageFlagSent, + actions: []any{add(proton.StarredLabel)}, + wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add inbox, add starred, remove inbox", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.InboxLabel), add(proton.StarredLabel), rem(proton.InboxLabel)}, + wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel}, + }, + { + name: "sent flag, add sent, add starred, remove sent", + flags: proton.MessageFlagSent, + actions: []any{add(proton.SentLabel), add(proton.StarredLabel), rem(proton.SentLabel)}, + wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel}, + }, + { + name: "received flag, add trash, remove trash", + flags: proton.MessageFlagReceived, + actions: []any{add(proton.TrashLabel), rem(proton.TrashLabel)}, + wantLabelIDs: []string{proton.AllMailLabel}, + }, + { + name: "sent flag, add trash, remove trash", + flags: proton.MessageFlagSent, + actions: []any{add(proton.TrashLabel), rem(proton.TrashLabel)}, + wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, + }, + } + + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass) + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, tt.flags, 1) + + require.True(t, (func() error { + for _, action := range tt.actions { + switch action := action.(type) { + case add: + if err := c.LabelMessages(ctx, []string{res[0].MessageID}, string(action)); err != nil { + return err + } + + case rem: + if err := c.UnlabelMessages(ctx, []string{res[0].MessageID}, string(action)); err != nil { + return err + } + } + } + + return nil + }() != nil) == tt.wantError) + + message, err := c.GetMessage(ctx, res[0].MessageID) + require.NoError(t, err) + + // The message should be in the correct labels. + require.ElementsMatch(t, tt.wantLabelIDs, message.LabelIDs) + + // The flags should be preserved after import. + require.True(t, message.Flags&tt.flags == tt.flags) + }) + } + }) + }) +} + +func TestServer_Import_FlagsAndLabels(t *testing.T) { + tests := []struct { + name string + labelIDs []string + flags proton.MessageFlag + wantLabelIDs []string + wantError bool + }{ + { + name: "received flag --> no label", + flags: proton.MessageFlagReceived, + wantLabelIDs: []string{proton.AllMailLabel}, + }, + { + name: "received flag --> inbox", + labelIDs: []string{proton.InboxLabel}, + flags: proton.MessageFlagReceived, + wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, + }, + { + name: "sent flag --> sent", + labelIDs: []string{proton.SentLabel}, + flags: proton.MessageFlagSent, + wantLabelIDs: []string{proton.SentLabel, proton.AllSentLabel, proton.AllMailLabel}, + }, + { + name: "received flag --> sent", + labelIDs: []string{proton.SentLabel}, + flags: proton.MessageFlagReceived, + wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, + }, + { + name: "sent flag --> inbox", + labelIDs: []string{proton.InboxLabel}, + flags: proton.MessageFlagSent, + wantLabelIDs: []string{proton.SentLabel, proton.AllSentLabel, proton.AllMailLabel}, + }, + { + name: "no flag --> drafts", + labelIDs: []string{proton.DraftsLabel}, + wantLabelIDs: []string{proton.DraftsLabel, proton.AllDraftsLabel, proton.AllMailLabel}, + }, + { + name: "forbidden: received flag --> All Mail", + labelIDs: []string{proton.AllMailLabel}, + flags: proton.MessageFlagReceived, + wantError: true, + }, + { + name: "forbidden: sent flag --> All Mail", + labelIDs: []string{proton.AllMailLabel}, + flags: proton.MessageFlagSent, + wantError: true, + }, + { + name: "forbidden: received flag --> inbox and all mail", + labelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, + flags: proton.MessageFlagReceived, + wantError: true, + }, + { + name: "forbidden: sent flag --> sent and all mail", + labelIDs: []string{proton.SentLabel, proton.AllMailLabel}, + flags: proton.MessageFlagSent, + wantError: true, + }, + { + name: "forbidden: received flag --> inbox and sent", + labelIDs: []string{proton.InboxLabel, proton.SentLabel}, + flags: proton.MessageFlagReceived, + wantError: true, + }, + { + name: "forbidden: sent flag --> inbox and sent", + labelIDs: []string{proton.InboxLabel, proton.SentLabel}, + flags: proton.MessageFlagSent, + wantError: true, + }, + { + name: "forbidden: received flag --> inbox and archive", + labelIDs: []string{proton.InboxLabel, proton.ArchiveLabel}, + flags: proton.MessageFlagReceived, + wantError: true, + }, + } + + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, pass) + require.NoError(t, err) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res, err := stream.Collect(ctx, c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), []proton.ImportReq{{ + Metadata: proton.ImportMetadata{ + AddressID: addr[0].ID, + Flags: tt.flags, + LabelIDs: tt.labelIDs, + }, + Message: []byte(fmt.Sprintf("From: sender@pm.me\r\nReceiver: recipient@pm.me\r\nSubject: %v\r\n\r\nHello World!", uuid.New())), + }}...)) + if tt.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, proton.SuccessCode, res[0].Code) + + message, err := c.GetMessage(ctx, res[0].MessageID) + require.NoError(t, err) + + // The message should be in the correct labels. + require.ElementsMatch(t, tt.wantLabelIDs, message.LabelIDs) + + // The flags should be preserved after import. + require.True(t, message.Flags&tt.flags == tt.flags) + } + }) + } + }) + }) +} + +func TestServer_PublicKeys(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + if _, _, err := s.CreateUser("other", "other@pm.me", []byte("pass")); err != nil { + t.Fatal(err) + } + + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + intKeys, intType, err := c.GetPublicKeys(ctx, "other@pm.me") + require.NoError(t, err) + require.Equal(t, proton.RecipientTypeInternal, intType) + require.Len(t, intKeys, 1) + + extKeys, extType, err := c.GetPublicKeys(ctx, "other@example.com") + require.NoError(t, err) + require.Equal(t, proton.RecipientTypeExternal, extType) + require.Len(t, extKeys, 0) + }) + }) +} + +func TestServer_Proxy(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + var calls []Call + + s.AddCallWatcher(func(call Call) { + calls = append(calls, call) + }) + + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(_ *proton.Client) { + proxy := New(WithProxyOrigin(s.GetHostURL())) + defer proxy.Close() + + m := proton.New( + proton.WithHostURL(proxy.GetProxyURL()), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + // Login -- the call should be proxied to the upstream server. + c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer c.Close() + + // The results of the call should be correct. + user, err := c.GetUser(ctx) + require.NoError(t, err) + require.Equal(t, "user", user.Name) + }) + + // Assert that the calls were proxied. + require.Greater(t, len(calls), 0) + }) +} + +func TestServer_Proxy_Cache(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(_ *proton.Client) { + proxy := New( + WithProxyOrigin(s.GetHostURL()), + WithAuthCacher(NewAuthCache()), + ) + defer proxy.Close() + + // Need to skip verifying the server proofs for the proxy cache feature to work! + m := proton.New( + proton.WithHostURL(proxy.GetProxyURL()), + proton.WithTransport(proton.InsecureTransport()), + proton.WithSkipVerifyProofs(), + ) + defer m.Close() + + // Login 3 times; we should produce 1 unique auth. + require.Len(t, xslices.Unique(iterator.Collect(iterator.Map(iterator.Counter(3), func(int) string { + c, auth, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer c.Close() + + return auth.UID + }))), 1) + }) + }) +} + +func TestServer_Proxy_AuthDelete(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(_ *proton.Client) { + proxy := New( + WithProxyOrigin(s.GetHostURL()), + WithAuthCacher(NewAuthCache()), + ) + defer proxy.Close() + + // Need to skip verifying the server proofs for the proxy cache feature to work! + m := proton.New( + proton.WithHostURL(proxy.GetProxyURL()), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + // Watch for login -- the calls should be proxied. + var login []Call + + s.AddCallWatcher(func(call Call) { + login = append(login, call) + }) + + // Login -- the call should be proxied to the upstream server. + c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer c.Close() + + // Assert that the login was proxied. + require.NotEmpty(t, len(login)) + + // Watch for logout -- logout should not be proxied to the upstream server. + var logout []Call + + s.AddCallWatcher(func(call Call) { + logout = append(logout, call) + }) + + // Logout -- the call should not be proxied to the upstream server. + require.NoError(t, c.AuthDelete(ctx)) + + // Assert that the logout was not proxied! + require.Empty(t, len(logout)) + }) + }) +} + +func TestServer_RealProxy(t *testing.T) { + username := os.Getenv("GO_PROTON_API_TEST_USERNAME") + password := os.Getenv("GO_PROTON_API_TEST_PASSWORD") + + if username == "" || password == "" { + t.Skip("skipping test, set the username and password to run") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + proxy := New() + defer proxy.Close() + + m := proton.New( + proton.WithHostURL(proxy.GetProxyURL()), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + // Login -- the call should be proxied to the upstream server. + c, _, err := m.NewClientWithLogin(ctx, username, []byte(password)) + require.NoError(t, err) + defer c.Close() + + // The results of the call should be correct. + user, err := c.GetUser(ctx) + require.NoError(t, err) + require.Equal(t, username, user.Name) +} + +func TestServer_RealProxy_Cache(t *testing.T) { + username := os.Getenv("GO_PROTON_API_TEST_USERNAME") + password := os.Getenv("GO_PROTON_API_TEST_PASSWORD") + + if username == "" || password == "" { + t.Skip("skipping test, set the username and password to run") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + proxy := New(WithAuthCacher(NewAuthCache())) + defer proxy.Close() + + m := proton.New( + proton.WithHostURL(proxy.GetProxyURL()), + proton.WithTransport(proton.InsecureTransport()), + proton.WithSkipVerifyProofs(), + ) + defer m.Close() + + // Login 3 times; we should produce 1 unique auth. + require.Len(t, xslices.Unique(iterator.Collect(iterator.Map(iterator.Counter(3), func(int) string { + c, auth, err := m.NewClientWithLogin(ctx, username, []byte(password)) + require.NoError(t, err) + defer c.Close() + + return auth.UID + }))), 1) +} + +func TestServer_Messages_Fetch(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { + ctl := proton.NewNetCtl() + + mm := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + ) + defer mm.Close() + + cc, _, err := mm.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer cc.Close() + + total := countBytesRead(ctl, func() { + res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) + require.NoError(t, err) + require.NotEmpty(t, res) + }) + + ctl.SetReadLimit(total / 2) + + require.Less(t, countBytesRead(ctl, func() { + res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) + require.Error(t, err) + require.Empty(t, res) + }), total) + + ctl.SetReadLimit(0) + + require.Equal(t, countBytesRead(ctl, func() { + res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) + require.NoError(t, err) + require.NotEmpty(t, res) + }), total) + }) + }) + }, WithTLS(false)) +} + +func TestServer_Messages_Status(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { + ctl := proton.NewNetCtl() + + mm := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + ) + defer mm.Close() + + statusCh := make(chan proton.Status, 1) + + mm.AddStatusObserver(func(status proton.Status) { + statusCh <- status + }) + + cc, _, err := mm.NewClientWithLogin(ctx, "user", []byte("pass")) + require.NoError(t, err) + defer cc.Close() + + total := countBytesRead(ctl, func() { + res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) + require.NoError(t, err) + require.NotEmpty(t, res) + }) + + ctl.SetReadLimit(total / 2) + + res, err := stream.Collect(ctx, cc.GetFullMessages(ctx, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) + require.Error(t, err) + require.Empty(t, res) + + require.Equal(t, proton.StatusDown, <-statusCh) + }) + }) + }, WithTLS(false)) +} + +func TestServer_Labels_Duplicates(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + req := proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeLabel, + } + + label, err := c.CreateLabel(context.Background(), req) + require.NoError(t, err) + require.Equal(t, req.Name, label.Name) + + _, err = c.CreateLabel(context.Background(), req) + require.Error(t, err) + }) + }) +} + +func TestServer_Labels_Duplicates_Update(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + label1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeLabel, + }) + require.NoError(t, err) + + label2, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeLabel, + }) + require.NoError(t, err) + + // Updating label1 with label2's name should fail. + _, err = c.UpdateLabel(context.Background(), label1.ID, proton.UpdateLabelReq{ + Name: label2.Name, + Color: label1.Color, + }) + require.Error(t, err) + + // Updating label1's color while preserving its name should succeed. + _, err = c.UpdateLabel(context.Background(), label1.ID, proton.UpdateLabelReq{ + Name: label1.Name, + Color: "#f00", + }) + require.NoError(t, err) + + // Updating label1 with a new name should succeed. + _, err = c.UpdateLabel(context.Background(), label1.ID, proton.UpdateLabelReq{ + Name: uuid.NewString(), + Color: label1.Color, + }) + require.NoError(t, err) + }) + }) +} + +func TestServer_Labels_Subfolders(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + + child, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + ParentID: parent.ID, + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + require.Equal(t, []string{parent.Name, child.Name}, child.Path) + + child2, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + ParentID: child.ID, + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + require.Equal(t, []string{parent.Name, child.Name, child2.Name}, child2.Path) + }) + }) +} + +func TestServer_Labels_Subfolders_Reassign(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + parent1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + + parent2, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + + // Create a child initially under parent1. + child, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + ParentID: parent1.ID, + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + require.Equal(t, []string{parent1.Name, child.Name}, child.Path) + + // Reassign the child to parent2. + child2, err := c.UpdateLabel(context.Background(), child.ID, proton.UpdateLabelReq{ + Name: child.Name, + Color: child.Color, + ParentID: parent2.ID, + }) + require.NoError(t, err) + require.Equal(t, []string{parent2.Name, child.Name}, child2.Path) + + // Reassign the child to no parent. + child3, err := c.UpdateLabel(context.Background(), child.ID, proton.UpdateLabelReq{ + Name: child2.Name, + Color: child2.Color, + ParentID: "", + }) + require.NoError(t, err) + require.Equal(t, []string{child3.Name}, child3.Path) + }) + }) +} + +func TestServer_Labels_Subfolders_DeleteParentWithChildren(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + + child, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + ParentID: parent.ID, + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + require.Equal(t, []string{parent.Name, child.Name}, child.Path) + + other, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ + Name: uuid.NewString(), + Color: "#f66", + Type: proton.LabelTypeFolder, + }) + require.NoError(t, err) + + // Get labels before. + before, err := c.GetLabels(context.Background(), proton.LabelTypeFolder) + require.NoError(t, err) + + // Delete the parent. + require.NoError(t, c.DeleteLabel(context.Background(), parent.ID)) + + // Get labels after. + after, err := c.GetLabels(context.Background(), proton.LabelTypeFolder) + require.NoError(t, err) + + // Both parent and child are deleted. + require.Equal(t, len(before)-2, len(after)) + + // The only label left is the other one. + require.Equal(t, other.ID, after[0].ID) + }) + }) +} + +func TestServer_AddressOrder(t *testing.T) { + withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { + withUser(ctx, t, s, m, "user", "email@pm.me", "pass", func(c *proton.Client) { + user, err := c.GetUser(context.Background()) + require.NoError(t, err) + + primary, err := c.GetAddresses(context.Background()) + require.NoError(t, err) + + // Create 3 additional addresses. + addr1, err := s.CreateAddress(user.ID, "addr1@pm.me", []byte("pass")) + require.NoError(t, err) + + addr2, err := s.CreateAddress(user.ID, "addr2@pm.me", []byte("pass")) + require.NoError(t, err) + + addr3, err := s.CreateAddress(user.ID, "addr3@pm.me", []byte("pass")) + require.NoError(t, err) + + addresses, err := c.GetAddresses(context.Background()) + require.NoError(t, err) + + // Check the order. + require.Equal(t, primary[0].ID, addresses[0].ID) + require.Equal(t, addr1, addresses[1].ID) + require.Equal(t, addr2, addresses[2].ID) + require.Equal(t, addr3, addresses[3].ID) + + // Update the order. + require.NoError(t, c.OrderAddresses(ctx, proton.OrderAddressesReq{ + AddressIDs: []string{addr3, addr2, addr1, primary[0].ID}, + })) + + // Check the order. + addresses, err = c.GetAddresses(context.Background()) + require.NoError(t, err) + + require.Equal(t, addr3, addresses[0].ID) + require.Equal(t, addr2, addresses[1].ID) + require.Equal(t, addr1, addresses[2].ID) + require.Equal(t, primary[0].ID, addresses[3].ID) + }) + }) +} + +func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := New(opts...) + defer s.Close() + + m := proton.New( + proton.WithHostURL(s.GetHostURL()), + proton.WithCookieJar(newTestCookieJar()), + proton.WithTransport(proton.InsecureTransport()), + ) + defer m.Close() + + fn(ctx, s, m) +} + +func withUser(ctx context.Context, t *testing.T, s *Server, m *proton.Manager, username, email, password string, fn func(c *proton.Client)) { + _, _, err := s.CreateUser(username, email, []byte(password)) + require.NoError(t, err) + + c, _, err := m.NewClientWithLogin(ctx, username, []byte(password)) + require.NoError(t, err) + defer c.Close() + + fn(c) +} + +func withMessages(ctx context.Context, t *testing.T, c *proton.Client, pass string, count int, fn func([]string)) { + user, err := c.GetUser(ctx) + require.NoError(t, err) + + addr, err := c.GetAddresses(ctx) + require.NoError(t, err) + + salt, err := c.GetSalts(ctx) + require.NoError(t, err) + + keyPass, err := salt.SaltForKey([]byte(pass), user.Keys.Primary().ID) + require.NoError(t, err) + + _, addrKRs, err := proton.Unlock(user, addr, keyPass) + require.NoError(t, err) + + fn(xslices.Map(importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, count), func(res proton.ImportRes) string { + return res.MessageID + })) +} + +func importMessages( + ctx context.Context, + t *testing.T, + c *proton.Client, + addrID string, + addrKR *crypto.KeyRing, + labelIDs []string, + flags proton.MessageFlag, + count int, +) []proton.ImportRes { + req := iterator.Collect(iterator.Map(iterator.Counter(count), func(int) proton.ImportReq { + return proton.ImportReq{ + Metadata: proton.ImportMetadata{ + AddressID: addrID, + LabelIDs: labelIDs, + Flags: flags, + Unread: true, + }, + Message: []byte(fmt.Sprintf("From: sender@pm.me\r\nReceiver: recipient@pm.me\r\nSubject: %v\r\n\r\nHello World!", uuid.New())), + } + })) + + res, err := stream.Collect(ctx, c.ImportMessages(ctx, addrKR, runtime.NumCPU(), runtime.NumCPU(), req...)) + require.NoError(t, err) + + return res +} + +func countBytesRead(ctl *proton.NetCtl, fn func()) uint64 { + var read uint64 + + ctl.OnRead(func(b []byte) { + atomic.AddUint64(&read, uint64(len(b))) + }) + + fn() + + return read +} + +type testCookieJar struct { + cookies map[string][]*http.Cookie + lock sync.RWMutex +} + +func newTestCookieJar() *testCookieJar { + return &testCookieJar{ + cookies: make(map[string][]*http.Cookie), + } +} + +func (j *testCookieJar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.lock.Lock() + defer j.lock.Unlock() + + j.cookies[u.Host] = cookies +} + +func (j *testCookieJar) Cookies(u *url.URL) []*http.Cookie { + j.lock.RLock() + defer j.lock.RUnlock() + + return j.cookies[u.Host] +} + +func must[T any](t T, err error) T { + if err != nil { + panic(err) + } + + return t +} + +func elementsMatch[T comparable](want, got []T) bool { + if len(want) != len(got) { + return false + } + + for _, w := range want { + if !slices.Contains(got, w) { + return false + } + } + + return true +} diff --git a/server/settings.go b/server/settings.go new file mode 100644 index 0000000..2893986 --- /dev/null +++ b/server/settings.go @@ -0,0 +1,21 @@ +package server + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetMailSettings() gin.HandlerFunc { + return func(c *gin.Context) { + settings, err := s.b.GetMailSettings(c.GetString("UserID")) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "MailSettings": settings, + }) + } +} diff --git a/server/testdata/text-plain.eml b/server/testdata/text-plain.eml new file mode 100644 index 0000000..bfcdb6e --- /dev/null +++ b/server/testdata/text-plain.eml @@ -0,0 +1,6 @@ +To: recipient@pm.me +From: sender@pm.me +Subject: Test +Content-Type: text/plain; charset=utf-8 + +Test \ No newline at end of file diff --git a/server/users.go b/server/users.go new file mode 100644 index 0000000..7e7e420 --- /dev/null +++ b/server/users.go @@ -0,0 +1,21 @@ +package server + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetUsers() gin.HandlerFunc { + return func(c *gin.Context) { + user, err := s.b.GetUser(c.GetString("UserID")) + if err != nil { + c.AbortWithStatus(http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{ + "User": user, + }) + } +} diff --git a/share.go b/share.go new file mode 100644 index 0000000..3c12358 --- /dev/null +++ b/share.go @@ -0,0 +1,39 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) ListShares(ctx context.Context, all bool) ([]Share, error) { + var res struct { + Shares []Share + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + if all { + r.SetQueryParam("ShowAll", "1") + } + + return r.SetResult(&res).Get("/drive/shares") + }); err != nil { + return nil, err + } + + return res.Shares, nil +} + +func (c *Client) GetShare(ctx context.Context, shareID string) (Share, error) { + var res struct { + Share + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/drive/shares/" + shareID) + }); err != nil { + return Share{}, err + } + + return res.Share, nil +} diff --git a/testdata/body.pgp b/testdata/body.pgp new file mode 100644 index 0000000..70ac4fa --- /dev/null +++ b/testdata/body.pgp @@ -0,0 +1,82 @@ +-----BEGIN PGP MESSAGE----- +Version: ProtonMail + +wcBMAwhkvvXurdhrAQf9GTfrNtYdoXtqGDLEEdTZMnnZ56rWUYJASCBnlTZt +QeKMG1t4S+YurOWrCJWiRQoYsS67ljHzs0cehzchWlTs0dAM5JKBmtFW0ZmV +U9b8vtXwLCHu2S9mH/Yz9MJaR1w8/M5IED4cbFR4If0sSyAEIuXGTBQF3LNc +5bajT9qN1EgeL0ILMEBstLyO95QdPqNkOTwUrzxEEmQXkzjWw0rkdin19UfC +s6Z7Ej2YZHQBUH8VAYQDnHKccvYDWcCpO4+r6MlEXQEDGbwt6W3wjUO7yu7R +RvK1A8l8JUOCGf0PJnXHr+HzsnkIbmZ9t4AJn2iSvS+MceeYBnGKFNK7Mmi+ +0NLLegFp8UtMpYZTyY/WnjvpG9KEi0RxIKAJOuwZDaGJXhIoyqOqLxiHEnbZ +JhDXK5nuJ5UAsrmsbjvzreIbNw8ToEMPd/9N3WwJB28gYNVMJdYluQACfvI1 +mk+yd1JR2dPfyP+SA7KMaUPXwsMWN86ViLZE0h+IANUkxAGGMAPU2ikhlbi0 +qZzRLx146sKHncZNfn3ge2noFNeJLSUvm5uzOxIjOgjKLyA33SCLWG7GYzaq +4vURCczZ7kJ/Y5M7P53y1uzMdp0AFsqhD0zMQZ35qmGHpmpN44srVxLk0piw +TchQcgs/F5lnW+HFsiILuWjwts1hbXeEeMugDNiuqYazP3qHmR8aSDyFZusB +fUuNegtuDXaI8HD6mIscToO1Jc2os0fPk/FMeo4vVFBQCv7hMk/hx8ytcDh7 +c1963+Rjopq97F8GhDc5S/7xO3oCd+1WSFN7BfyNL0NmfUWM4Fur67Ru0VmW +QQmesAeYnGzvCEfGA+tDZYp6gu4/HBfHPMjZs+27MgnkvdOHjLaTQRSmKLeo +Oh3BECB9lT0qO85k8+wPZzFQvHpDOwxbW+7u5IZBolGNmO3erE/5NitTWOsJ +qfbgYg34+3Juim5E4n8a5npcSMewea7IFtlebOoKSTJ6F8SX4lhky/HdcCjR +jzHRh+bpnLEApufKfRplEQLOZoxxLJM++ey2MT32gZLjAXMjAt5eRyhOOrQe +v6GdSDbDtMKWhEkTAbBbh2lgi8ZwW1w05pGXmeXXxP4wjBeG1hj4IfT1tFa3 +w4IlOU0wzk+jaMBA0Mg4I5VUGNIvhUzU6td8GYN+ZCOQedSJZ/AkkJxeXRLD +Sy0NTwW+/tyT9WDL7hfg8T2Dh9dzdXuHS23FAZ2irrzy/bIB/tAOtdqGgf7e +KW1JWu5TjKzcMx2aP+ZpOJG6ieAqbU0xPCU2d7HdhHSV0tSwiDfEnV2u2lOJ +HtInII03NOhmOe1MzWmmztBw4e1bEtVI62rRup7pFPubIlOiC1fmWzyFmlvY +3g/QF69j2Lg7UjBBlZ1IZ17TMe3oCr5hNZED5vdz8i6ECbhsaGFxWddAoyea +V2TEW19pAc5/fs+v87ZTRZzoMXK+uTH+wVWT4AYYZ3GUWod3jiWcjeQqjNB+ +VhkUSNbBYdJ7CBpUu5eb82N5ZE6v8jV1dw+kKgRE4DRzASP7Z/2yUi+ze2D5 +i0TKyFtqMwfw3ptckUAWE7r57B7dbNY8b5woKaVBCkVyEnY0eDIqQUU/RoZq +tgKki+4KgwNlIbT6OpxPVAog6fY6M3/3crO52cqWQPX/Uhp0J5rZXLEMRdmX +M/7UvFcW96cYtOeCucux9WnNPWqSJ4ddzwt6lxqptsZU7jZS0IGRLqhAn23e +u1/NdQpK91YIDeFIlMdyc9/h8VEnaSU/LE1+28INGJsET1dUcw4Qwxb1uOur +kleYddoFs4tgi44uxNq3o+Pw9EfFFPu2t1U6FESN4qc+aJ/Sv8Oipi2fyDVv +7iD5XrJ4Oe4a2HXZLfvAaQig4cfa0x4Hqiqp/YSjWvyIZpIoiI+AKDS3bNXs +e47mADlHpv4l/rKA/Q5Qjs8zRB74So0Y0oG1qXNRbcPADUrSyMAxqInz/Q3c +9CA35GSp7js2V5rKyv2a11z6VdLtPw/e3KsEvdVmx/7DxSeGvXO5lwN6RluK +DJgIl9/7dVLuNvuVXk+tDCrs2dUHOf+a5N34x9B9AQGT3Ixo80le8e9+U5Bc +r30iycgQcSISUw5NKCpOTj5BuabYfDQYKCA28l+lrEGq1+q4vumMQgGxEJJn +gTqJTyQOE0abbohHl4/+TkKWNdWTwth0eZHHbs0Xrjf1mq7GnCh49QxkIxQt +VDwxE/BkiO5SkU0CWdMGLL2LtaincqGtdIlZB7emRB8qdZOQt+5wp+XhiXbG +DHoPV3739XzxojZUrOzyh6aoezXmntoHqucr/V85BwSJ5F8rjZiHQDoPuwmf +oE3ABLMTC3T9rO2MfxepGbFxNmBaXUrO/muChvq29NO83o+KdPP6YJVu/qAL +utMeEhBPgd+if956Ph2twjM1qKO5Hp1Z178KM+2R2PDz1kNmKtsNfiNJhrQ9 +1qmLgNCp9qZCFqExHRHcmpgv94KuUYWyONZYBKBCAJKAPM5IRVsryOYUbUlS +jeVO3aZES89nCh0zF4XiJ0LaI6kxCCzje2+cZBY5jxiUw4I5F4wOawCwSsv2 +mLAuhBV/H0KWq9o0e2RlXMRtrDTgfdnRGHnfXqbQDIgsgQmga7zUej9cVOz2 +HA6Oml/h8TjJPQVXkrTuKN4NohWTsz8QX8ISdXuaWoMn9MVq57aYweFV97tT +UdzLvJCioq7+AiFu2WN9NmU/bmN2ye5gfn8Qw7eKTfD3VJFkF8Y6lw6RLm0D +loMHIv9gaECLcxiJyINjttH+Shx2Nl2Pj5WOgixK+SnClAQ5Uzk7IZTj4CdJ +YbIoXg3ukyTeGUaG2flwSbsRUBz0PFPAznTtovtgY55tne4EI7qpdwbo5+fi +CBeGylEoozfuKcsnaOJJ4F6pdeBuhwoAV4drvI/qNFRhs82pV85n9aP+7XM/ +f16vdig5rn1zodef/mlz989pwfDwjfwabQliNAfczX+rAur8CQcUs62qs41x +vQjeSc30NQzR1+B2MlsQtuKxk7HJUlp4tpuldAfC1ffwSThfXKqjSpbDOTRU +xVz1xlnjZoQQPhmrNCj0CqpKbqsjq5qEyvgM4llvXedw7wDdXIp52fIlwwO2 +CwgXS3sTAa46YWkbqG+iIEFowjUse0T8tK/+3dtnYkS5Hen2rp1lUdl/mm2Z +CPqJjOSXGs5Iz+4A2JOSBscYX7GDh/IcF0d9r55K4f5JvJqwYr2MnrQcOJLJ +e3EC6weTi0d1a59eafY7m6b4cS2zRQPUtvcOVzkVXbvXMyjWi+WDJYl8EFPr +mIOa/Ij7FCiuEOMlazwN5Ot8AIh0qSV9VubqVPUEyc6DozjqdxDL0ZLJxlTy +fWTwR//spAgr/pRqhZxNCvfHFn1NeF7TFbDAPZojfEChF9n1u951jgq5sZ01 +aXAiIo5bwVPc/FxHLEf0Zu5f/SS5ZNgIVquJGDNwPVD5S6o0h8w+UjT66Mkp +X6wzIP9u4BAkvTHJVwkk79F/MmRPYYjwQooRHcnI+hD1ezliXZ9vgxysFR5g +QqftD2tvkwPbxFXOer1mD5MSSMXWGXlhpLBPxiLFRurReBCabL53818Q4P80 +Ky8rBUnw5p8XTdQ/fKaMfpQGHzmgiBaOU+sQnWR2IcErpuEWJB2qKHduUgep ++if+VOXi3xe9oE/hc7GqX7kV7zmHPBgV9ixczvgouEbRsqSd1LlGYIOd7Suf +1cl7IoxEc2bnbZMdbWPQwyM10D+YDTP+hRs+fjAYoQA+iKy3fg6zHDZKo9BI +6NZQNSGmVn5apP6EpUsFLK9qWltryTTXabbLEdVGOHLSgX3F30wJt1IwtSoV +BabBFJARq2JTeYEeOyQryyI18nSNQ3AWn30lbWo/Zm5+MxXCnsm46uvagPBM +kbMJT/GbsMVU+gVj6BwJz4wH9AY1K0LZL2+ip8/zkwTMp5kXpQTkslEv3spQ +f+wHEduwgfsANaoELc4L3NNcZyR8BD3AWilIEHwmHSoAxNXMjZzIC8JidttG +18GuoiFxf4PG6QR810y1NFw+eR3cFEOa6D5OZP8f8xlHDGJSCv+Rokkc4JLo +1kEGMbYY6IFW9DGaM3s3NOKO14U6YJQUpae23MF/+lTP5LbQadeWh4myU5Wg +TefM7rv9dl1yPZnCNINWhSuxshrBse8wbGwUJZ9Ix9KwnU8zR/HoDm0QvPt4 +XYXTvVFGF6ynUaDe2xFfgmBMHNuPPJ+pv/rJoGqZ5th2HOhWVuuTRWKK1JBT +IKgTnkDXOXDDdqHohr9r5AeQ10thqi0Pi3PzcynNH706qcEdAK9G/dYJ7lXg +FDNPwRJYrTrvLwEEA+7rYEoq3PajLndz96hou2sRX28/6J67VBpiOkfcgk9t +SJ6Srr3Tn/Ud4YnmcB1LQx0goQC4Bmf/Q8iPHnRimi16g8vSl2wTx2EEr+40 +RZcJWj68/3EETH5qPgFdp4m6teRnB7tE6wudCwsNsJ51LcGr0aKxvDVSZyAA +jM09plC+5rUAdYF5teVfZZFQLt/PFAfapvvzEWE= +=zZ65 +-----END PGP MESSAGE----- + diff --git a/testdata/prv.asc b/testdata/prv.asc new file mode 100644 index 0000000..66193be --- /dev/null +++ b/testdata/prv.asc @@ -0,0 +1,64 @@ +-----BEGIN PGP PRIVATE KEY BLOCK----- +Version: OpenPGP.js v4.10.10 +Comment: https://openpgpjs.org + +xcMGBF4yECUBCACmt9I8R1+ibe0pHa/PMC1Zs07BFSjTB4k2B8EZJhc3dlgU +WoNj06HKAKTnF0tCJfojb9Hhhns9E3da+/pr2mTeGLGCmxVlHmS7vt5MKezh +rGmqT9QMJAgjgDlL8+ecWaapRyIQ4NXBX7H3DPuUDSVBxISSbXLxfORiYP5E +DueF6aLLMrJFgFamFq5kgpFvXYZmPv+h/VoLP+ZLxnPfS92W3jJ3Y6ByZcC9 +bsHiqBFAnGDnh0nIW109rEZRa8vKeSJL/48hUSxucvdAGWQyVpA2vhiWXzba +3O1gmjcphVrQCCgSyvyDvYF/uS3mR1F8d//BpKuYKInq6+Y5MPeKRJGVABEB +AAH+CQMISThK/g0sWIlgIp1NC8BxDEFdWTKFvOVMFrmX11zKW5gBJClW6WDE +sgbfZi62hLUBAJCwGfH4jH+737RIzFyuNvLTU0Kh7hcOBiEdORIQ6GPq0PFk +fDdy4zqd5pFdvvG7bAF+SPWSh9ydFd6kkoelsg0vqUMQSD8cLRw3ESeyzdh1 +Rw1sidfpReHO8Q27G95bX1nvK+tBlFyUhJAVP/kfDw0Pn21L45QPuO5ueB1R +vbuihvHeFydZ6FmsNOI5avxHFmxOqO9qP/Bxs4R14JLMsQFt//n+Dp5pnF4F +GzTTYHevY9SCjlMkn1x25FTmJB+aMDJrN+yrEw/oCsDF8Y5ZfueHaqDY46Lv +oW4tkDExK045BDNH5yE0Z/3zjz2FpUfLiLqU8YnR72SoxHaRRuIdDlNIuiGc +FYy0yZsj30SVSVDKL/A+Nu1Gz2GvHociTqqJQiyBXcp5C0DjEmz97QP8y1o3 +atg71RAIOCWReu9YqbAWWdBZo0dannohPHTUeHMhh9bd5GTamG/xf5SrTBTg +ooVWwlGuF8/NQpYGqjvENgr5a5Q0SqWUm31CeYcHLI70TooTbmIH6gvlUax1 +O9VmzUpX9GlZQ1fwkHp6WAG+0cwMie2Nlwt0Ul4WjGQS0HeyYKGMriY/vMIx +IFScwIVzWY3/nbVS5Z7BGrE3sNEOP313o603DKq+lasT3WKpRYQ9EJxvPdjW +En0IcSeZnlLBM3ZAcPRlh4dmVQwEYx8nasmbIBj80nANA96jMT2X8Sft1zEN +OpLPR86l8eRAkmbdAZ+GZIl6xkxsdxc+cad2OJubs+ze6CbjfCWE6zICUvE1 +PQr1fARRmhyL6Ixpio8TTaSztaBh/QjF6vcU2dpApzri/k+w+UX9ZzP0LNbw +VncvsJNMEWNzFZ0pmbfZX0rZwb9C+1yKzSdzY2hpem9mcmVuaWNAcG0ubWUg +PHNjaGl6b2ZyZW5pY0BwbS5tZT7CwHYEEAEIACAFAl4yECUGCwkHCAMCBBUI +CgIEFgIBAAIZAQIbAwIeAQAKCRDssm2jN69QH0w+B/9dpwofZoAMu7eeVS5B +tkvLiCWJqWtoBIU2TfB1nAzIbOJA5cWegobKsEtBVI//QAQBwcjg+BjjXmGM +KkmO6suDARMrATtct+G2kUl7FpRDE47okq0s+2KJb7bAaPQoBOx5xwFQM3Tj +JkD9C+1xSJIrcIgpk1Rs129cNZXKXNc01v02xTrszYnbLqvneYFY4Qt1AUTP +bB9us3c6nx0dDq0phGJwRbUOptNZbQrrJ3F3zmVKwGZLju0L0Gmy/F/AOMk4 +S3hh/LOTcmSJ0ytZPtTkrTUKmCqDkEspTO4c17y1ffS/4LfzdnPFJJEXjvYU +DQNtphFfgsbQeBmuzF2yQTyCx8MGBF4yECUBCADTP/kymrU/DLbGK6kgiUAB +UU4zH7Rq6u1NVqKwdaBKOulMKst4QSlVfixI2IDjG2JgUJbCDjhqmgQ3AbDz +Z7xOxUqscvM9xsVBbZM5KW+k5cOeAPGNu/GEz62gz/sUTQ5ZGLMjX+C31/3b +olKNWuke46mBmPIcv0of7/izanZSRqUeJ4+KsWQjorPmmurqt3TCRq1h3dlm +itHQTlQLn9EWRvvTIQagzh5bma7nfwIdnTLfRQW4JX/W6t09O3wj+g5t/X4S +dTbXTHnjkLYahXiFDII+2KEcYGWOrs3HeJRb8GEhuOI0g5yK2ezX7RLpzSjj +nql4oiQvczPO73fC4N83ABEBAAH+CQMIVc335hmgh/5gkyB7ZXJRkmU5v9yw +CXscCDvKBTEFnVPteLUNi7D41TRWwQMIlahu106doJavog9PSwft+tQ590pl +x7BwlE5+Rfr48svaHkkm8/AoewScpkqIH0Z/m2/LSpta3Lpjlj0ea8KUkk0V +BM/KzwoF6hz85ZglT/s49MohLT6bowhgTE/ZeoLvJJ+NN40KPf9+2ZU3vR8Q +NpnzqAXIg5iLuUKrAkic0r3DeOUbKebznMJevN6l6DNBQk1BpRI/talsID99 +rk/OaQ8fSXC63NLmXNBg7Oig9iyQJYnqc0le3d3QztzQLgIg79S4PVlQuQvA +truSPSrCaLcmqarjuqKDEA4zUzYcGDSKbLXn5JUI9Y74k+CK/IN5OyocwDn3 +THAY63+rLgjEyzd3MKtpS2G6cv7FQVVNvk9j73v6zvlrWXYD8BClklLAhiH5 +FjCuEnNf3EV6r6ztMp4/UJoqS5N+qrQHLR8upufHoEmmhtL0HLkcqkT7O/9O +LJEKREqcM7w8oWogSQLHbZ5XxWXodDH+smJbu8aySsJT0EED8agiQaN7WTVu +yMGYUms/U/SoCvXPLYnnFzSf4W6xiD6o5kreaFg5OvFlyxQevtcbcU3vXSYW +QDXqr69Yp4lZXm0gOoDLBCCOkqhQcMJka58eP0hSQkCDsMiwxSk86JRtdw+W +H8s1KE9/noKipu2g3v79n1bY0SzkVchGtPi6iV4SarUYEfchkxp8ZlEcnbOy +mJJaIIHoseZ23un/78Bu+YUDJ3kjeBSwIKYrCe1+51Z+xXnS5/tanX2LFhB6 +gJ7zLABynXrzu7UHxb3zWmHLZxeiYpVYFwrntTPYd7peOiSc/NDWiYLK1SwK +nMnpdnLi0/LLuRMsWKKuIITw8sGfki1oxIh9D/6bWEe6eDJIJnAfieGbwsBf +BBgBCAAJBQJeMhAlAhsMAAoJEOyybaM3r1AfgywH/01kOihA5/Q8doMipNkZ +az3+4ZcAnPeqnIx6ba8xQTLL38Z6xy7SrTQyCLv01dMJVbRqie0ypk4Zeyxw +CK7mMqJAq5vMuj/voKCjFZnW3wszxRV3p+U9/SlPn9Rirg2DVFwjScRYro4P +3Tml0oMmFD2jD7QkATwWdhYjTKwET8eCtv7CciKe9EOad6b4vLCiXpT6TiuS +MPHS4iUgVMKL4jhVAQoqZDOMMN+odt7yKzhtUaF8VQKLwwHh/JC8BTuweNJN +5doj2cZGjeDu8HQzW/kDSVGQA05rOfZBX4hk8Cpm82ChsjJA7vb3AQFir3CD +qF70cirPDl5bNhyYkJm6Asw= +=82rN +-----END PGP PRIVATE KEY BLOCK----- diff --git a/testdata/pub.asc b/testdata/pub.asc new file mode 100644 index 0000000..bc16066 --- /dev/null +++ b/testdata/pub.asc @@ -0,0 +1,35 @@ +-----BEGIN PGP PUBLIC KEY BLOCK----- + +xsDNBGCwvxYBDACtFOvVIma53f1RLCaE3LtaIaY+sVHHdwsB8g13Kl0x5sK53AchIVR+6RE0JHG1 +pbwQX4Hm05w6cjemDo652Cjn946zXQ65GYMYiG9Uw+HVldk3TsmKHdvI3zZNQkihnGSMP65BG5Mi +6M3Yq/5FAEP3cOCUKJKkSd6KEx6x3+mbjoPnb4fV0OlfNZa1+FDVlE1gkH3GKQIdcutF5nMDvxry +RHM20vnR1YPrY587Uz6JTnarxCeENn442W/aiG5O2FXgt5QKW66TtTzESry/y6JEpg9EiLKG0Ki4 +k6Z2kkP+YS5xvmqSohVqusmBnOk+wppIhrWaxGJ08Rv5HgzGS3gS29XmzxlBDE+FCrOVSOjAQ94g +UtHZMIPL91A2JMc3RbOXpqVPNyJ+dRzQZ1obyXoaaoiLCQlBtVSbCKUOLVY+bmpyqUdSx45k31Hf +FSUj8KrkjsCw6QFpVEfa5LxKfLHfulZdjL3FquxiYjrLHsYmdlIY2lqtaQocINk6VTa+YkkAEQEA +Ac0cQlFBIDxwbS5icmlkZ2UucWFAZ21haWwuY29tPsLBDwQTAQgAORYhBMTS4mxV82UN59X4Y1MP +t/KzWl0zBQJgsL8WBQkFo5qAAhsDBQsJCAcCBhUICQoLAgUWAgMBAAAKCRBTD7fys1pdMw0dC/9w +Ud0I1lp/AHztlIrPYgwwJcSK7eSxuHXelX6mFImjlieKcWjdBL4Gj8MyOxPgjRDW7JecRZA/7tMI +37+izWH2CukevGrbpdyuzX0AR7I7DpX4tDVFNTxi7vYDk+Q+lVJ5dL4lYww3t7cuzhqUvj4oSJaS +9cNeFc66owij7juQoQQ7DmOsLMUw9qlMsDvZNvu83x7hIyGLBCY1gY1VtCeb3QT7uCG8LrQrWkI9 +RLgzZioegHxMtvUgzQRw8U9mS8lJ4J2LaI3Z4DliyKSEebplVMfl53dSl1wfV5huZKifoo9NAusw +lrRw+3Ae+VZ0Obnz14qmyCwevHv6QlkXtntSY1wyprOvzWiu8PE9rHoTmwLI8wMkbiLdFVXCZbon +/1Hg0n1K0fv1A8cIc5JSeCe3y8YMm7b5oEie/cnArqDjZ8VB/vm5H9zvHxfJCI5FwlEVBlosSpib +Tm/1fSpqDgAmH7IDe3wCY8899kmfbBqJzr+5xaCGt+0mgC8jpJIEIKHOwM0EYLC/FwEMAKtvqck9 +78vAr1ttKpOAEQcKf1X04QLy2AvzHGNcud+XC1u0bHLm3OQsYyLaP3DVAvain6vrVVGiswdsexUI +yIEpBTo+9Rco7MtwwESfxG10p2bbd8q74EaJZkt/ifL6oxEYgp8tCgAB6tqGoXCmkG0nKszrrTTz +Lo/3bHjzfxF01oGDNlQVGVwW+8d5tjV5vowxeSjmdIZXJPNep4Lah/xFisWb71VwdzVEaOi6k7rQ +J5k+Dp1wrCqW1H5RZZt6dGweU4LbuTYBWtnw/2YKz+hBOYGDzil9hqTG9fRXu31d4xOZxuZkv61R +3DWrxuECKUHgJvFaao0KSnBDa/T/RMJ9Y/KQ0bx0zXOTtoDOhOhpMA8JUTMfWb3Uul50ikxLI5EJ +xnBroy2bLLaRW6ijMgpdnZRAtmhssHipOisxXoxiWMoRfJBR01DhbmSQPTjpsjqM2Z24hPcKN+sf +9kCKTmaJ2hbOfurriPmM0GHdgewbf5cemKgqVaPfhvyBXhnRjwARAQABwsD8BBgBCAAmFiEExNLi +bFXzZQ3n1fhjUw+38rNaXTMFAmCwvxcFCQWjmoACGwwACgkQUw+38rNaXTNTSgwAqomSuzK80Goi +eOqJ6e0LLiKJTGzMtrtugK9HYzFn1rT7n9W2lZuf4X8Ayo9i32Q4Of1V17EXOyYWHOK/prTDd9DV +sRa+fzLVzC6jln3AKeRi9k/DIs7GDs0poQZyttTVLilK8uDkEWM7mWAyjyBTtWyiKTlfFb7W+M3R +1lTKXQsn/wBkboJNZj+VTNo5NZ6vIx4PJRFW2lsDKbYJ+Vh5vZUdTwHXr5gLadtWzrVgBVMiLyEr +fgCzdyfMRy+g4uoYxt9JuFvisU/DDVNeAZ8hSgLdI4w65wjeXtT0syzpL9+pJQX0McugEpbIEiOt +e55OL1C0hjvHnsLHPkRuUOtQKru/gNl0bLqZ7mYqPNhJbh/58k+N4eoeTvCjMy65anWuiWjPbm16 +GH/3erZiijKDGYn8UqldiOK9dTC6DbvyJdxuYFliV7cSWIBtiOeGrajxzkuUHMW+d1d4l2gPqs2+ +eT1x4J+7ydQgCvyyI4W01xcFlAL70VRTlYKIbMXJBZ6L +=9sH1 +-----END PGP PUBLIC KEY BLOCK----- diff --git a/ticker.go b/ticker.go new file mode 100644 index 0000000..8791093 --- /dev/null +++ b/ticker.go @@ -0,0 +1,58 @@ +package proton + +import ( + "math/rand" + "time" +) + +type Ticker struct { + C chan time.Time + + stopCh chan struct{} + doneCh chan struct{} +} + +// NewTicker returns a new ticker that ticks at a random time between period and period+jitter. +// It can be stopped by closing calling Stop(). +func NewTicker(period, jitter time.Duration) *Ticker { + t := &Ticker{ + C: make(chan time.Time), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + + go func() { + defer close(t.doneCh) + + for { + select { + case <-t.stopCh: + return + + case <-time.After(withJitter(period, jitter)): + select { + case <-t.stopCh: + return + + case t.C <- time.Now(): + // ... + } + } + } + }() + + return t +} + +func (t *Ticker) Stop() { + close(t.stopCh) + <-t.doneCh +} + +func withJitter(period, jitter time.Duration) time.Duration { + if jitter == 0 { + return period + } + + return period + time.Duration(rand.Int63n(int64(jitter))) +} diff --git a/undo.go b/undo.go new file mode 100644 index 0000000..9ffb391 --- /dev/null +++ b/undo.go @@ -0,0 +1,28 @@ +package proton + +import ( + "context" + "runtime" + "time" + + "github.com/bradenaw/juniper/parallel" + "github.com/go-resty/resty/v2" +) + +func (c *Client) UndoActions(ctx context.Context, tokens ...UndoToken) ([]UndoRes, error) { + return parallel.MapContext(ctx, runtime.NumCPU(), tokens, func(ctx context.Context, token UndoToken) (UndoRes, error) { + if time.Unix(token.ValidUntil, 0).Before(time.Now()) { + return UndoRes{}, ErrUndoTokenExpired + } + + var res UndoRes + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(token).SetResult(&res).Post("/mail/v4/undoactions") + }); err != nil { + return UndoRes{}, err + } + + return res, nil + }) +} diff --git a/undo_types.go b/undo_types.go new file mode 100644 index 0000000..a7c0901 --- /dev/null +++ b/undo_types.go @@ -0,0 +1,14 @@ +package proton + +import "errors" + +var ErrUndoTokenExpired = errors.New("undo token expired") + +type UndoToken struct { + Token string + ValidUntil int64 +} + +type UndoRes struct { + Messages []Message +} diff --git a/unlock.go b/unlock.go new file mode 100644 index 0000000..731bb87 --- /dev/null +++ b/unlock.go @@ -0,0 +1,32 @@ +package proton + +import ( + "fmt" + "runtime" + + "github.com/ProtonMail/gopenpgp/v2/crypto" + "github.com/bradenaw/juniper/parallel" +) + +func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRing, map[string]*crypto.KeyRing, error) { + userKR, err := user.Keys.Unlock(saltedKeyPass, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to unlock user keys: %w", err) + } + + addrKRs := make(map[string]*crypto.KeyRing) + + for idx, addrKR := range parallel.Map(runtime.NumCPU(), addresses, func(addr Address) *crypto.KeyRing { + return addr.Keys.TryUnlock(saltedKeyPass, userKR) + }) { + if addrKR != nil { + addrKRs[addresses[idx].ID] = addrKR + } + } + + if len(addrKRs) == 0 { + return nil, nil, fmt.Errorf("failed to unlock any address keys") + } + + return userKR, addrKRs, nil +} diff --git a/user.go b/user.go new file mode 100644 index 0000000..17d061a --- /dev/null +++ b/user.go @@ -0,0 +1,47 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) GetUser(ctx context.Context) (User, error) { + var res struct { + User User + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/core/v4/users") + }); err != nil { + return User{}, err + } + + return res.User, nil +} + +func (c *Client) SendVerificationCode(ctx context.Context, req SendVerificationCodeReq) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).Post("/core/v4/users/code") + }) +} + +func (c *Client) CreateUser(ctx context.Context, req CreateUserReq) (User, error) { + var res struct { + User User + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).Post("/core/v4/users") + }); err != nil { + return User{}, err + } + + return res.User, nil +} + +func (c *Client) GetUsernameAvailable(ctx context.Context, username string) error { + return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetQueryParam("Name", username).Get("/core/v4/users/available") + }) +} diff --git a/user_types.go b/user_types.go new file mode 100644 index 0000000..1c64da5 --- /dev/null +++ b/user_types.go @@ -0,0 +1,58 @@ +package proton + +type User struct { + ID string + Name string + DisplayName string + Email string + Keys Keys + + UsedSpace int + MaxSpace int + MaxUpload int + + Credit int + Currency string +} + +type TokenType string + +const ( + EmailTokenType TokenType = "email" + SMSTokenType TokenType = "sms" +) + +type SendVerificationCodeReq struct { + Username string + Type TokenType + Destination TokenDestination +} + +type TokenDestination struct { + Address string + Phone string +} + +type UserType int + +const ( + MailUserType UserType = iota + 1 + VPNUserType +) + +type CreateUserReq struct { + Username string + Email string `json:",omitempty"` + Phone string `json:",omitempty"` + Token string + TokenType TokenType + Type UserType + Auth CreateUserAuth +} + +type CreateUserAuth struct { + Version int + ModulusID string + Salt string + Verifier string +} diff --git a/volume.go b/volume.go new file mode 100644 index 0000000..7625190 --- /dev/null +++ b/volume.go @@ -0,0 +1,21 @@ +package proton + +import ( + "context" + + "github.com/go-resty/resty/v2" +) + +func (c *Client) ListVolumes(ctx context.Context) ([]Volume, error) { + var res struct { + Volumes []Volume + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetResult(&res).Get("/drive/volumes") + }); err != nil { + return nil, err + } + + return res.Volumes, nil +}