diff --git a/message_draft_types.go b/message_draft_types.go index ab47a63..6530b1e 100644 --- a/message_draft_types.go +++ b/message_draft_types.go @@ -36,3 +36,8 @@ type CreateDraftReq struct { ParentID string `json:",omitempty"` Action CreateDraftAction } + +type UpdateDraftReq struct { + Message DraftTemplate + AttachmentKeyPackets []string +} diff --git a/message_send.go b/message_send.go index 9bdc9ff..7f0e28f 100644 --- a/message_send.go +++ b/message_send.go @@ -20,6 +20,20 @@ func (c *Client) CreateDraft(ctx context.Context, req CreateDraftReq) (Message, return res.Message, nil } +func (c *Client) UpdateDraft(ctx context.Context, draftID string, req UpdateDraftReq) (Message, error) { + var res struct { + Message Message + } + + if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(req).SetResult(&res).Put("/mail/v4/messages/" + draftID) + }); err != nil { + return Message{}, err + } + + return res.Message, nil +} + func (c *Client) SendDraft(ctx context.Context, draftID string, req SendDraftReq) (Message, error) { var res struct { Sent Message diff --git a/server/backend/api.go b/server/backend/api.go index 99839c3..db2dfe2 100644 --- a/server/backend/api.go +++ b/server/backend/api.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "errors" "fmt" - "net/mail" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-proton-api" @@ -420,30 +419,50 @@ func (b *Backend) DeleteMessage(userID, messageID string) error { }) } -func (b *Backend) CreateDraft( - userID, addrID string, - subject string, - sender *mail.Address, - toList, ccList, bccList []*mail.Address, - armBody string, - mimeType rfc822.MIMEType, - externalID string, -) (proton.Message, error) { +func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate) (proton.Message, error) { return withAcc(b, userID, func(acc *account) (proton.Message, error) { return withMessages(b, func(messages map[string]*message) (proton.Message, error) { - msg := newMessage(addrID, subject, sender, toList, ccList, bccList, armBody, mimeType, externalID) + return withLabels(b, func(labels map[string]*label) (proton.Message, error) { + msg := newMessageFromTemplate(addrID, draft) - messages[msg.messageID] = msg + // Drafts automatically get the sysLabel "Drafts". + msg.addLabel(proton.DraftsLabel, labels) - updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID}) - if err != nil { - return proton.Message{}, err - } + messages[msg.messageID] = msg - acc.messageIDs = append(acc.messageIDs, msg.messageID) - acc.updateIDs = append(acc.updateIDs, updateID) + updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID}) + if err != nil { + return proton.Message{}, err + } - return msg.toMessage(nil), nil + acc.messageIDs = append(acc.messageIDs, msg.messageID) + acc.updateIDs = append(acc.updateIDs, updateID) + + return msg.toMessage(nil), nil + }) + }) + }) +} + +func (b *Backend) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) (proton.Message, error) { + return withAcc(b, userID, func(acc *account) (proton.Message, error) { + return withMessages(b, func(messages map[string]*message) (proton.Message, error) { + return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) { + if _, ok := messages[draftID]; !ok { + return proton.Message{}, fmt.Errorf("message %q not found", draftID) + } + + messages[draftID].applyChanges(changes) + + updateID, err := b.newUpdate(&messageUpdated{messageID: draftID}) + if err != nil { + return proton.Message{}, err + } + + acc.updateIDs = append(acc.updateIDs, updateID) + + return messages[draftID].toMessage(atts), nil + }) }) }) } diff --git a/server/backend/backend.go b/server/backend/backend.go index 7ce433f..539cd03 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -362,27 +362,6 @@ func (b *Backend) CreateMessage( }) } -func (b *Backend) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) (string, error) { - return withAcc(b, userID, func(acc *account) (string, error) { - return withMessages(b, func(messages map[string]*message) (string, error) { - if _, ok := messages[draftID]; !ok { - return "", fmt.Errorf("message %q not found", draftID) - } - - messages[draftID].applyChanges(changes) - - updateID, err := b.newUpdate(&messageUpdated{messageID: draftID}) - if err != nil { - return "", err - } - - acc.updateIDs = append(acc.updateIDs, updateID) - - return draftID, nil - }) - }) -} - func (b *Backend) Encrypt(userID, addrID, decBody string) (string, error) { return withAcc(b, userID, func(acc *account) (string, error) { pubKey, err := acc.addresses[addrID].keys[0].getPubKey() diff --git a/server/backend/message.go b/server/backend/message.go index de63abb..0927230 100644 --- a/server/backend/message.go +++ b/server/backend/message.go @@ -59,6 +59,24 @@ func newMessage( } } +func newMessageFromTemplate(addrID string, template proton.DraftTemplate) *message { + return &message{ + messageID: uuid.NewString(), + externalID: template.ExternalID, + addrID: addrID, + sysLabel: pointer(""), + + subject: template.Subject, + sender: template.Sender, + toList: template.ToList, + ccList: template.CCList, + bccList: template.BCCList, + + armBody: template.Body, + mimeType: template.MIMEType, + } +} + func (msg *message) toMessage(att map[string]*attachment) proton.Message { return proton.Message{ MessageMetadata: msg.toMetadata(), @@ -167,8 +185,7 @@ func (msg *message) getParsedHeaders() proton.Headers { // applyChanges will apply non-nil field from passed message. // -// NOTE: This is not feature complete. It might panic on non-implemented -// changes. +// NOTE: This is not feature complete. It might panic on non-implemented changes. func (msg *message) applyChanges(changes proton.DraftTemplate) { if changes.Subject != "" { msg.subject = changes.Subject diff --git a/server/messages.go b/server/messages.go index d0c0f76..d515f37 100644 --- a/server/messages.go +++ b/server/messages.go @@ -18,14 +18,11 @@ import ( func (s *Server) handleGetMailMessages() gin.HandlerFunc { return func(c *gin.Context) { - filter := proton.MessageFilter{ - ID: c.QueryArray("ID"), - } s.getMailMessages( c, mustParseInt(c.DefaultQuery("Page", "0")), mustParseInt(c.DefaultQuery("PageSize", "100")), - filter, + proton.MessageFilter{ID: c.QueryArray("ID")}, ) } } @@ -88,18 +85,7 @@ func (s *Server) postMailMessages(c *gin.Context) { return } - message, err := s.b.CreateDraft( - c.GetString("UserID"), - addrID, - req.Message.Subject, - req.Message.Sender, - req.Message.ToList, - req.Message.CCList, - req.Message.BCCList, - req.Message.Body, - req.Message.MIMEType, - req.Message.ExternalID, - ) + message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return @@ -165,6 +151,27 @@ func (s *Server) handlePostMailMessage() gin.HandlerFunc { } } +func (s *Server) handlePutMailMessage() gin.HandlerFunc { + return func(c *gin.Context) { + var req proton.UpdateDraftReq + + if err := c.BindJSON(&req); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + message, err := s.b.UpdateDraft(c.GetString("UserID"), c.Param("messageID"), req.Message) + if err != nil { + c.AbortWithStatus(http.StatusUnprocessableEntity) + return + } + + c.JSON(http.StatusOK, gin.H{ + "Message": message, + }) + } +} + func (s *Server) handlePutMailMessagesRead() gin.HandlerFunc { return func(c *gin.Context) { var req proton.MessageActionReq diff --git a/server/router.go b/server/router.go index 20d9006..5243577 100644 --- a/server/router.go +++ b/server/router.go @@ -82,6 +82,7 @@ func initRouter(s *Server) { messages.GET("/ids", s.handleGetMailMessageIDs()) messages.GET("/:messageID", s.handleGetMailMessage()) messages.POST("/:messageID", s.handlePostMailMessage()) + messages.PUT("/:messageID", s.handlePutMailMessage()) messages.PUT("/read", s.handlePutMailMessagesRead()) messages.PUT("/unread", s.handlePutMailMessagesUnread()) messages.PUT("/label", s.handlePutMailMessagesLabel()) diff --git a/server/server.go b/server/server.go index 1e0e119..a05362d 100644 --- a/server/server.go +++ b/server/server.go @@ -150,12 +150,6 @@ func (s *Server) UnlabelMessage(userID, msgID, labelID string) error { return s.b.UnlabelMessages(userID, labelID, msgID) } -func (s *Server) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) error { - _, err := s.b.UpdateDraft(userID, draftID, changes) - - return err -} - func (s *Server) SetAuthLife(authLife time.Duration) { s.b.SetAuthLife(authLife) } diff --git a/server/server_test.go b/server/server_test.go index 984060e..fdb439f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -588,6 +588,7 @@ func TestServer_CreateMessage(t *testing.T) { require.Equal(t, addresses[0].ID, draft.AddressID) require.Equal(t, "My subject", draft.Subject) require.Equal(t, &mail.Address{Address: "email@pm.me"}, draft.Sender) + require.ElementsMatch(t, []string{proton.AllMailLabel, proton.AllDraftsLabel, proton.DraftsLabel}, draft.LabelIDs) }) }) } @@ -598,6 +599,7 @@ func TestServer_UpdateDraft(t *testing.T) { addresses, err := c.GetAddresses(ctx) require.NoError(t, err) + // Create the draft. draft, err := c.CreateDraft(ctx, proton.CreateDraftReq{ Message: proton.DraftTemplate{ Subject: "My subject", @@ -606,33 +608,31 @@ func TestServer_UpdateDraft(t *testing.T) { }, }) require.NoError(t, err) - require.Equal(t, addresses[0].ID, draft.AddressID) require.Equal(t, "My subject", draft.Subject) require.Equal(t, &mail.Address{Address: "email@pm.me"}, draft.Sender) - user, err := c.GetUser(ctx) - require.NoError(t, err) - + // Create an event stream to watch for an update event. fromEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) - // Draft updated on server side. - _, err = s.b.UpdateDraft(user.ID, draft.ID, proton.DraftTemplate{ - Subject: "Edited subject", - ToList: []*mail.Address{{Address: "edited@pm.me"}}, - Body: "Edited body", + // Update the draft subject/to-list. + msg, err := c.UpdateDraft(ctx, draft.ID, proton.UpdateDraftReq{ + Message: proton.DraftTemplate{ + Subject: "Edited subject", + ToList: []*mail.Address{{Address: "edited@pm.me"}}, + }, }) require.NoError(t, err) + require.Equal(t, "Edited subject", msg.Subject) - var updated *proton.MessageMetadata - + // We should eventually get an update event. require.Eventually(t, func() bool { event := <-eventCh - if len(event.Messages) != 1 { + if len(event.Messages) < 1 { return false } @@ -644,14 +644,12 @@ func TestServer_UpdateDraft(t *testing.T) { return false } - updated = &event.Messages[0].Message + require.Equal(t, draft.ID, event.Messages[0].ID) + require.Equal(t, "Edited subject", event.Messages[0].Message.Subject) + require.Equal(t, []*mail.Address{{Address: "edited@pm.me"}}, event.Messages[0].Message.ToList) return true }, 5*time.Second, time.Millisecond*100) - - require.Equal(t, draft.ID, updated.ID) - require.Equal(t, "Edited subject", updated.Subject) - require.Equal(t, []*mail.Address{{Address: "edited@pm.me"}}, updated.ToList) }) }) }