feat(GODT-2500): Add panic handlers everywhere.

Jakub
2023-03-23 13:48:13 +01:00
committed by cuthix
parent 20351a206b
commit 1d5a7231e2
21 changed files with 144 additions and 37 deletions
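The change itself is mechanical and repeated across the 21 files below: every background goroutine, worker callback, and parallel map now receives an optional queue.PanicHandler and invokes it from a defer, so a panic in background work can be reported by the application instead of silently killing the process. A minimal, self-contained sketch of the idea (the single-method interface shape is inferred from how the handler is called in the hunks below; the protect helper and the recover-and-log behaviour are illustrative only, not part of this commit):

```go
package main

import (
	"log"
	"sync"
)

// PanicHandler mirrors the single-method shape used throughout the diff
// (assumed from how gluon's queue.PanicHandler is invoked: HandlePanic()).
type PanicHandler interface {
	HandlePanic()
}

// protect runs fn on its own goroutine. The deferred recover stops a panic
// in fn from killing the process, and the optional handler is notified,
// mirroring the nil-check-then-call pattern added in this commit.
func protect(wg *sync.WaitGroup, handler PanicHandler, fn func()) {
	wg.Add(1)

	go func() {
		defer wg.Done()

		defer func() {
			if r := recover(); r != nil {
				log.Printf("recovered from panic in background goroutine: %v", r)

				if handler != nil {
					handler.HandlePanic()
				}
			}
		}()

		fn()
	}()
}

func main() {
	var wg sync.WaitGroup

	protect(&wg, nil, func() { panic("boom") })

	wg.Wait()
}
```

Everything that follows is this same pattern threaded through constructors, options, and call sites.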

View File

@@ -3,6 +3,8 @@ package proton
import (
"bytes"
"context"
+ "github.com/ProtonMail/gluon/queue"
"github.com/bradenaw/juniper/parallel"
)
@@ -56,10 +58,11 @@ func (SequentialScheduler) Schedule(ctx context.Context, attachmentIDs []string,
}
type ParallelScheduler struct {
- workers int
+ workers      int
+ panicHandler queue.PanicHandler
}
- func NewParallelScheduler(workers int) *ParallelScheduler {
+ func NewParallelScheduler(workers int, panicHandler queue.PanicHandler) *ParallelScheduler {
if workers == 0 {
workers = 1
}
@@ -67,6 +70,12 @@ func NewParallelScheduler(workers int) *ParallelScheduler {
return &ParallelScheduler{workers: workers}
}
+ func (p *ParallelScheduler) handlePanic() {
+ if p.panicHandler != nil {
+ p.panicHandler.HandlePanic()
+ }
+ }
func (p ParallelScheduler) Schedule(ctx context.Context, attachmentIDs []string, storageProvider AttachmentAllocator, downloader func(context.Context, string, *bytes.Buffer) error) ([]*bytes.Buffer, error) {
// If we have fewer attachments than the maximum number of workers, reduce the worker count to match the attachment count.
workers := p.workers
@@ -75,6 +84,8 @@ func (p ParallelScheduler) Schedule(ctx context.Context, attachmentIDs []string,
}
return parallel.MapContext(ctx, workers, attachmentIDs, func(ctx context.Context, id string) (*bytes.Buffer, error) {
+ defer p.handlePanic()
buffer := storageProvider.NewBuffer()
if err := downloader(ctx, id, buffer); err != nil {
return nil, err

View File

@@ -46,7 +46,7 @@ func (c *Client) GetAllCalendarEvents(ctx context.Context, calendarID string, fi
return nil, err
}
- return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]CalendarEvent, error) {
+ return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]CalendarEvent, error) {
return c.GetCalendarEvents(ctx, calendarID, page, pageSize, filter)
})
}

View File

@@ -72,7 +72,7 @@ func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) {
return nil, err
}
- return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]Contact, error) {
+ return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]Contact, error) {
return c.GetContacts(ctx, page, pageSize)
})
}
@@ -101,7 +101,7 @@ func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]Conta
return nil, err
}
- return fetchPaged(ctx, total, maxPageSize, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) {
+ return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) {
return c.GetContactEmails(ctx, email, page, pageSize)
})
}

View File

@@ -60,9 +60,11 @@ func (c *Client) NewEventStream(ctx context.Context, period, jitter time.Duratio
eventCh := make(chan Event)
go func() {
+ defer c.m.handlePanic()
defer close(eventCh)
- ticker := NewTicker(period, jitter)
+ ticker := NewTicker(period, jitter, c.m.panicHandler)
defer ticker.Stop()
for {

View File

@@ -1,7 +1,10 @@
package proton
+ import "github.com/ProtonMail/gluon/queue"
type Future[T any] struct {
- resCh chan res[T]
+ resCh        chan res[T]
+ panicHandler queue.PanicHandler
}
type res[T any] struct {
@@ -9,26 +12,40 @@ type res[T any] struct {
err error
}
- func NewFuture[T any](fn func() (T, error)) *Future[T] {
+ func NewFuture[T any](panicHandler queue.PanicHandler, fn func() (T, error)) *Future[T] {
resCh := make(chan res[T])
+ job := &Future[T]{
+ resCh: resCh,
+ panicHandler: panicHandler,
+ }
go func() {
+ defer job.handlePanic()
val, err := fn()
resCh <- res[T]{val: val, err: err}
}()
- return &Future[T]{resCh: resCh}
+ return job
}
func (job *Future[T]) Then(fn func(T, error)) {
go func() {
+ defer job.handlePanic()
res := <-job.resCh
fn(res.val, res.err)
}()
}
+ func (job *Future[T]) handlePanic() {
+ if job.panicHandler != nil {
+ job.panicHandler.HandlePanic()
+ }
+ }
func (job *Future[T]) Get() (T, error) {
res := <-job.resCh
@@ -36,15 +53,16 @@ func (job *Future[T]) Get() (T, error) {
}
type Group[T any] struct {
- futures []*Future[T]
+ futures      []*Future[T]
+ panicHandler queue.PanicHandler
}
- func NewGroup[T any]() *Group[T] {
- return &Group[T]{}
+ func NewGroup[T any](panicHandler queue.PanicHandler) *Group[T] {
+ return &Group[T]{panicHandler: panicHandler}
}
func (group *Group[T]) Add(fn func() (T, error)) {
- group.futures = append(group.futures, NewFuture(fn))
+ group.futures = append(group.futures, NewFuture(group.panicHandler, fn))
}
func (group *Group[T]) Result() ([]T, error) {

View File

@@ -5,13 +5,14 @@ import (
"testing"
"time"
+ "github.com/ProtonMail/gluon/queue"
"github.com/stretchr/testify/require"
)
func TestFuture(t *testing.T) {
resCh := make(chan int)
- NewFuture(func() (int, error) {
+ NewFuture(queue.NoopPanicHandler{}, func() (int, error) {
return 42, nil
}).Then(func(res int, err error) {
resCh <- res
@@ -21,7 +22,7 @@ func TestFuture(t *testing.T) {
}
func TestGroup(t *testing.T) {
- group := NewGroup[int]()
+ group := NewGroup[int](queue.NoopPanicHandler{})
for i := 0; i < 10; i++ {
i := i

go.mod (2 changes)
View File

@@ -63,3 +63,5 @@ require (
google.golang.org/genproto v0.0.0-20230221151758-ace64dc21148 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+ replace github.com/ProtonMail/gluon => /home/dev/gopath18/src/gluon

View File

@@ -6,6 +6,7 @@ import (
"runtime"
"testing"
+ "github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/iterator"
"github.com/bradenaw/juniper/stream"
@@ -31,7 +32,7 @@ func createTestMessages(t *testing.T, c *proton.Client, pass string, count int)
keyPass, err := salt.SaltForKey([]byte(pass), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, keyPass)
+ _, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{})
require.NoError(t, err)
req := iterator.Collect(iterator.Map(iterator.Counter(count), func(i int) proton.ImportReq {

View File

@@ -6,6 +6,7 @@ import (
"net"
"sync"
+ "github.com/ProtonMail/gluon/queue"
"github.com/go-resty/resty/v2"
)
@@ -19,6 +20,8 @@ type Manager struct {
errHandlers map[Code][]Handler
verifyProofs bool
+ panicHandler queue.PanicHandler
}
func New(opts ...Option) *Manager {
@@ -128,3 +131,9 @@ func (m *Manager) onConnUp() {
observer(m.status)
}
}
+ func (m *Manager) handlePanic() {
+ if m.panicHandler != nil {
+ m.panicHandler.HandlePanic()
+ }
+ }

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"time"
+ "github.com/ProtonMail/gluon/queue"
"github.com/go-resty/resty/v2"
)
@@ -25,6 +26,7 @@ type managerBuilder struct {
retryCount int
logger resty.Logger
debug bool
+ panicHandler queue.PanicHandler
}
func newManagerBuilder() *managerBuilder {
@@ -37,6 +39,7 @@ func newManagerBuilder() *managerBuilder {
retryCount: 3,
logger: nil,
debug: false,
+ panicHandler: queue.NoopPanicHandler{},
}
}
@@ -47,6 +50,8 @@ func (builder *managerBuilder) build() *Manager {
errHandlers: make(map[Code][]Handler),
verifyProofs: builder.verifyProofs,
+ panicHandler: builder.panicHandler,
}
// Set the API host.

View File

@@ -61,7 +61,7 @@ func (c *Client) GetMessageMetadata(ctx context.Context, filter MessageFilter) (
return nil, err
}
- return fetchPaged(ctx, count, maxPageSize, func(ctx context.Context, page, pageSize int) ([]MessageMetadata, error) {
+ return fetchPaged(ctx, count, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]MessageMetadata, error) {
return c.GetMessageMetadataPage(ctx, page, pageSize, filter)
})
}
@@ -87,6 +87,8 @@ func (c *Client) DeleteMessage(ctx context.Context, messageIDs ...string) error
pages := xslices.Chunk(messageIDs, maxPageSize)
return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error {
+ defer c.m.handlePanic()
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/delete")
})
@@ -97,6 +99,8 @@ func (c *Client) MarkMessagesRead(ctx context.Context, messageIDs ...string) err
pages := xslices.Chunk(messageIDs, maxPageSize)
return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error {
+ defer c.m.handlePanic()
return c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/read")
})
@@ -107,6 +111,8 @@ func (c *Client) MarkMessagesUnread(ctx context.Context, messageIDs ...string) e
pages := xslices.Chunk(messageIDs, maxPageSize)
return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error {
+ defer c.m.handlePanic()
req := MessageActionReq{IDs: pages[idx]}
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
@@ -125,6 +131,8 @@ func (c *Client) LabelMessages(ctx context.Context, messageIDs []string, labelID
runtime.NumCPU(),
xslices.Chunk(messageIDs, maxPageSize),
func(ctx context.Context, messageIDs []string) (LabelMessagesRes, error) {
+ defer c.m.handlePanic()
var res LabelMessagesRes
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
@@ -164,6 +172,8 @@ func (c *Client) UnlabelMessages(ctx context.Context, messageIDs []string, label
runtime.NumCPU(),
xslices.Chunk(messageIDs, maxPageSize),
func(ctx context.Context, messageIDs []string) (LabelMessagesRes, error) {
+ defer c.m.handlePanic()
var res LabelMessagesRes
if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {

View File

@@ -45,6 +45,8 @@ func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, wor
workers,
buffer,
func(ctx context.Context, req []ImportReq) (stream.Stream[ImportRes], error) {
+ defer c.m.handlePanic()
res, err := c.importMessages(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to import messages: %w", err)

View File

@@ -3,6 +3,7 @@ package proton
import (
"net/http"
+ "github.com/ProtonMail/gluon/queue"
"github.com/go-resty/resty/v2"
)
@@ -122,3 +123,17 @@ type withDebug struct {
func (opt withDebug) config(builder *managerBuilder) {
builder.debug = opt.debug
}
+ func WithPanicHandler(panicHandler queue.PanicHandler) Option {
+ return &withPanicHandler{
+ panicHandler: panicHandler,
+ }
+ }
+ type withPanicHandler struct {
+ panicHandler queue.PanicHandler
+ }
+ func (opt withPanicHandler) config(builder *managerBuilder) {
+ builder.panicHandler = opt.panicHandler
+ }
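For callers, opting in is a single option on the existing builder. A hypothetical call site (sketch: proton.New and WithPanicHandler appear in the hunks above, and queue.NoopPanicHandler{} is the default the builder falls back to when the option is omitted):

```go
package main

import (
	"github.com/ProtonMail/gluon/queue"
	"github.com/ProtonMail/go-proton-api"
)

func main() {
	// Wire a panic handler into the manager; replace NoopPanicHandler with
	// an application-specific handler to actually report recovered panics.
	m := proton.New(proton.WithPanicHandler(queue.NoopPanicHandler{}))
	_ = m
}
```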

View File

@@ -13,7 +13,7 @@ const maxPageSize = 150
func fetchPaged[T any](
ctx context.Context,
- total, pageSize int,
+ total, pageSize int, c *Client,
fn func(ctx context.Context, page, pageSize int) ([]T, error),
) ([]T, error) {
return stream.Collect(ctx, stream.Flatten(parallel.MapStream(
@@ -22,6 +22,8 @@ func fetchPaged[T any](
runtime.NumCPU(),
runtime.NumCPU(),
func(ctx context.Context, page int) (stream.Stream[T], error) {
+ defer c.m.handlePanic()
values, err := fn(ctx, page, pageSize)
if err != nil {
return nil, err

pool.go (19 changes)
View File

@@ -14,23 +14,26 @@ var ErrJobCancelled = errors.New("job cancelled by surrounding context")
// Pool is a worker pool that handles input of type In and returns results of type Out.
type Pool[In comparable, Out any] struct {
- queue *queue.QueuedChannel[*job[In, Out]]
- wg sync.WaitGroup
+ queue        *queue.QueuedChannel[*job[In, Out]]
+ wg           sync.WaitGroup
+ panicHandler queue.PanicHandler
}
// doneFunc must be called to free up pool resources.
type doneFunc func()
// New returns a new pool.
- func NewPool[In comparable, Out any](size int, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
+ func NewPool[In comparable, Out any](size int, panicHandler queue.PanicHandler, work func(context.Context, In) (Out, error)) *Pool[In, Out] {
pool := &Pool[In, Out]{
- queue: queue.NewQueuedChannel[*job[In, Out]](0, 0),
+ queue: queue.NewQueuedChannel[*job[In, Out]](0, 0, panicHandler),
}
for i := 0; i < size; i++ {
pool.wg.Add(1)
go func() {
+ defer pool.handlePanic()
defer pool.wg.Done()
for job := range pool.queue.GetChannel() {
@@ -55,6 +58,12 @@ func NewPool[In comparable, Out any](size int, work func(context.Context, In) (O
return pool
}
+ func (pool *Pool[In, Out]) handlePanic() {
+ if pool.panicHandler != nil {
+ pool.panicHandler.HandlePanic()
+ }
+ }
// Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred.
func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(int, In, Out, error) error) error {
ctx, cancel := context.WithCancel(ctx)
@@ -72,6 +81,8 @@ func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(int,
wg.Add(1)
go func(index int) {
+ defer pool.handlePanic()
defer wg.Done()
job, done, err := pool.newJob(ctx, req)

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
+ "github.com/ProtonMail/gluon/queue"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -153,7 +154,7 @@ func TestPool_ProcessAll(t *testing.T) {
}
func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] {
- return NewPool(workers, func(ctx context.Context, req int) (int, error) {
+ return NewPool(workers, queue.NoopPanicHandler{}, func(ctx context.Context, req int) (int, error) {
if len(delay) > 0 {
time.Sleep(delay[0])
}
@@ -163,7 +164,7 @@ func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] {
}
func newDoublerWithError(workers int) *Pool[int, int] {
- return NewPool(workers, func(ctx context.Context, req int) (int, error) {
+ return NewPool(workers, queue.NoopPanicHandler{}, func(ctx context.Context, req int) (int, error) {
if req%2 == 0 {
return 0, errors.New("oops")
}

View File

@@ -483,11 +483,12 @@ func (s *Server) importBody(
headerDate := header.Get("Date")
if len(headerDate) != 0 {
- if d, err := mail.ParseDate(headerDate); err != nil {
+ d, err := mail.ParseDate(headerDate)
+ if err != nil {
return "", err
- } else {
- date = d
}
+ date = d
}
// NOTE: Importing without sender adds empty sender on API side

View File

@@ -126,7 +126,7 @@ func (opt withTLS) config(builder *serverBuilder) {
builder.withTLS = opt.withTLS
}
- // withDomain controls the domain of the server.
+ // WithDomain controls the domain of the server.
func WithDomain(domain string) Option {
return &withDomain{
domain: domain,

View File

@@ -19,6 +19,7 @@ import (
"github.com/bradenaw/juniper/parallel"
"github.com/Masterminds/semver/v3"
+ "github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
@@ -612,7 +613,7 @@ func TestServer_CreateMessage(t *testing.T) {
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, pass)
+ _, addrKRs, err := proton.Unlock(user, addr, pass, queue.NoopPanicHandler{})
require.NoError(t, err)
draft, err := c.CreateDraft(ctx, addrKRs[addr[0].ID], proton.CreateDraftReq{
@@ -648,7 +649,7 @@ func TestServer_UpdateDraft(t *testing.T) {
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, pass)
+ _, addrKRs, err := proton.Unlock(user, addr, pass, queue.NoopPanicHandler{})
require.NoError(t, err)
// Create the draft.
@@ -724,7 +725,7 @@ func TestServer_SendMessage(t *testing.T) {
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, pass)
+ _, addrKRs, err := proton.Unlock(user, addr, pass, queue.NoopPanicHandler{})
require.NoError(t, err)
draft, err := c.CreateDraft(ctx, addrKRs[addr[0].ID], proton.CreateDraftReq{
@@ -807,7 +808,7 @@ func TestServer_Import(t *testing.T) {
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, pass)
+ _, addrKRs, err := proton.Unlock(user, addr, pass, queue.NoopPanicHandler{})
require.NoError(t, err)
res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, 1)
@@ -1026,7 +1027,7 @@ func TestServer_Labels(t *testing.T) {
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, pass)
+ _, addrKRs, err := proton.Unlock(user, addr, pass, queue.NoopPanicHandler{})
require.NoError(t, err)
for _, tt := range tests {
@@ -1168,7 +1169,7 @@ func TestServer_Import_FlagsAndLabels(t *testing.T) {
pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, pass)
+ _, addrKRs, err := proton.Unlock(user, addr, pass, queue.NoopPanicHandler{})
require.NoError(t, err)
for _, tt := range tests {
@@ -1845,7 +1846,7 @@ func withMessages(ctx context.Context, t *testing.T, c *proton.Client, pass stri
keyPass, err := salt.SaltForKey([]byte(pass), user.Keys.Primary().ID)
require.NoError(t, err)
- _, addrKRs, err := proton.Unlock(user, addr, keyPass)
+ _, addrKRs, err := proton.Unlock(user, addr, keyPass, queue.NoopPanicHandler{})
require.NoError(t, err)
fn(xslices.Map(importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, count), func(res proton.ImportRes) string {

View File

@@ -3,6 +3,8 @@ package proton
import (
"math/rand"
"time"
+ "github.com/ProtonMail/gluon/queue"
)
type Ticker struct {
@@ -14,7 +16,7 @@ type Ticker struct {
// NewTicker returns a new ticker that ticks at a random time between period and period+jitter.
// It can be stopped by calling Stop().
- func NewTicker(period, jitter time.Duration) *Ticker {
+ func NewTicker(period, jitter time.Duration, panicHandler queue.PanicHandler) *Ticker {
t := &Ticker{
C: make(chan time.Time),
stopCh: make(chan struct{}),
@@ -22,6 +24,12 @@ func NewTicker(period, jitter time.Duration) *Ticker {
}
go func() {
+ defer func() {
+ if panicHandler != nil {
+ panicHandler.HandlePanic()
+ }
+ }()
defer close(t.doneCh)
for {
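For reference, a hypothetical call site for the updated constructor (sketch: NewTicker, the C channel, and Stop() appear in the hunk above; the period, jitter, and printing are illustrative only):

```go
package main

import (
	"fmt"
	"time"

	"github.com/ProtonMail/gluon/queue"
	"github.com/ProtonMail/go-proton-api"
)

func main() {
	// The handler is now a required third argument; NoopPanicHandler keeps
	// the previous behaviour of doing nothing on panic.
	ticker := proton.NewTicker(time.Second, 100*time.Millisecond, queue.NoopPanicHandler{})
	defer ticker.Stop()

	fmt.Println("tick at", <-ticker.C)
}
```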

View File

@@ -4,11 +4,12 @@ import (
"fmt"
"runtime"
+ "github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/bradenaw/juniper/parallel"
)
- func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRing, map[string]*crypto.KeyRing, error) {
+ func Unlock(user User, addresses []Address, saltedKeyPass []byte, panicHandler queue.PanicHandler) (*crypto.KeyRing, map[string]*crypto.KeyRing, error) {
userKR, err := user.Keys.Unlock(saltedKeyPass, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to unlock user keys: %w", err)
@@ -19,6 +20,12 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi
addrKRs := make(map[string]*crypto.KeyRing)
for idx, addrKR := range parallel.Map(runtime.NumCPU(), addresses, func(addr Address) *crypto.KeyRing {
+ defer func() {
+ if panicHandler != nil {
+ panicHandler.HandlePanic()
+ }
+ }()
return addr.Keys.TryUnlock(saltedKeyPass, userKR)
}) {
if addrKR == nil {