diff --git a/auth_test.go b/auth_test.go index 771d634..a2fd727 100644 --- a/auth_test.go +++ b/auth_test.go @@ -16,7 +16,7 @@ func TestAuth(t *testing.T) { s := server.New() defer s.Close() - _, _, err := s.CreateUser("user", []byte("password")) + _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) m := proton.New( @@ -26,14 +26,14 @@ func TestAuth(t *testing.T) { defer m.Close() // Create one session. - c1, auth1, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + c1, auth1, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) // Revoke all other sessions. require.NoError(t, c1.AuthRevokeAll(context.Background())) // Create another session. - c2, _, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + c2, _, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) // There should be two sessions. @@ -61,7 +61,7 @@ func TestAuth_Refresh(t *testing.T) { defer s.Close() // Create a user on the server. - userID, _, err := s.CreateUser("user", []byte("password")) + userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) // The auth is valid for 4 seconds. @@ -74,7 +74,7 @@ func TestAuth_Refresh(t *testing.T) { defer m.Close() // Create one session for the user. - c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) @@ -85,7 +85,7 @@ func TestAuth_Refresh(t *testing.T) { { user, err := c.GetUser(context.Background()) require.NoError(t, err) - require.Equal(t, "username", user.Name) + require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) } @@ -96,7 +96,7 @@ func TestAuth_Refresh(t *testing.T) { { user, err := c.GetUser(context.Background()) require.NoError(t, err) - require.Equal(t, "username", user.Name) + require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) } } @@ -106,7 +106,7 @@ func TestAuth_Refresh_Multi(t *testing.T) { defer s.Close() // Create a user on the server. - userID, _, err := s.CreateUser("user", []byte("password")) + userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) // The auth is valid for 4 seconds. @@ -118,7 +118,7 @@ func TestAuth_Refresh_Multi(t *testing.T) { ) defer m.Close() - c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) @@ -128,7 +128,7 @@ func TestAuth_Refresh_Multi(t *testing.T) { parallel.Do(runtime.NumCPU(), 100, func(idx int) { user, err := c.GetUser(context.Background()) require.NoError(t, err) - require.Equal(t, "username", user.Name) + require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) }) @@ -139,7 +139,7 @@ func TestAuth_Refresh_Multi(t *testing.T) { parallel.Do(runtime.NumCPU(), 100, func(idx int) { user, err := c.GetUser(context.Background()) require.NoError(t, err) - require.Equal(t, "username", user.Name) + require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) }) } @@ -149,7 +149,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) { defer s.Close() // Create a user on the server. - userID, _, err := s.CreateUser("user", []byte("password")) + userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) m := proton.New( @@ -159,7 +159,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) { defer m.Close() // Create one session for the user. - c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password")) + c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) @@ -172,7 +172,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) { { user, err := c.GetUser(context.Background()) require.NoError(t, err) - require.Equal(t, "username", user.Name) + require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) } diff --git a/dialer.go b/dialer.go deleted file mode 100644 index 66c8eb1..0000000 --- a/dialer.go +++ /dev/null @@ -1,311 +0,0 @@ -package proton - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "sync" -) - -func InsecureTransport() *http.Transport { - return &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } -} - -// NetCtl can be used to control whether a dialer can dial, and whether the resulting -// connection can read or write. -type NetCtl struct { - canDial atomicBool - dialLimit atomicUint64 - - canRead atomicBool - readLimit atomicUint64 - - canWrite atomicBool - writeLimit atomicUint64 - - onDial []func(net.Conn) - onRead []func([]byte) - onWrite []func([]byte) - - lock sync.Mutex -} - -// NewNetCtl returns a new NetCtl with all fields set to true. -func NewNetCtl() *NetCtl { - return &NetCtl{ - canDial: atomicBool{b32(true)}, - canRead: atomicBool{b32(true)}, - canWrite: atomicBool{b32(true)}, - } -} - -// SetCanDial sets whether the dialer can dial. -func (c *NetCtl) SetCanDial(canDial bool) { - c.canDial.Store(canDial) -} - -// SetDialLimit sets the maximum number of times dialers using this controller can dial. -func (c *NetCtl) SetDialLimit(limit uint64) { - c.dialLimit.Store(limit) -} - -// SetCanRead sets whether the connection can read. -func (c *NetCtl) SetCanRead(canRead bool) { - c.canRead.Store(canRead) -} - -// SetReadLimit sets the maximum number of bytes that can be read. -func (c *NetCtl) SetReadLimit(limit uint64) { - c.readLimit.Store(limit) -} - -// SetCanWrite sets whether the connection can write. -func (c *NetCtl) SetCanWrite(canWrite bool) { - c.canWrite.Store(canWrite) -} - -// SetWriteLimit sets the maximum number of bytes that can be written. -func (c *NetCtl) SetWriteLimit(limit uint64) { - c.writeLimit.Store(limit) -} - -// OnDial adds a callback that is called with the created connection when a dial is successful. -func (c *NetCtl) OnDial(f func(net.Conn)) { - c.lock.Lock() - defer c.lock.Unlock() - - c.onDial = append(c.onDial, f) -} - -// OnRead adds a callback that is called with the read bytes when a read is successful. -func (c *NetCtl) OnRead(f func([]byte)) { - c.lock.Lock() - defer c.lock.Unlock() - - c.onRead = append(c.onRead, f) -} - -// OnWrite adds a callback that is called with the written bytes when a write is successful. -func (c *NetCtl) OnWrite(f func([]byte)) { - c.lock.Lock() - defer c.lock.Unlock() - - c.onWrite = append(c.onWrite, f) -} - -// Disable is equivalent to disallowing dial, read and write. -func (c *NetCtl) Disable() { - c.SetCanDial(false) - c.SetCanRead(false) - c.SetCanWrite(false) -} - -// Enable is equivalent to allowing dial, read and write. -func (c *NetCtl) Enable() { - c.SetCanDial(true) - c.SetCanRead(true) - c.SetCanWrite(true) -} - -// Conn is a wrapper around net.Conn that can be used to control whether a connection can read or write. -type Conn struct { - net.Conn - - ctl *NetCtl - - readLimiter *readLimiter - writeLimiter *writeLimiter -} - -// Read reads from the wrapped connection, but only if the controller allows it. -func (c *Conn) Read(b []byte) (int, error) { - if !c.ctl.canRead.Load() { - return 0, errors.New("cannot read") - } - - n, err := c.readLimiter.read(c.Conn, b) - if err != nil { - return n, err - } - - for _, f := range c.ctl.onRead { - f(b[:n]) - } - - return n, err -} - -// Write writes to the wrapped connection, but only if the controller allows it. -func (c *Conn) Write(b []byte) (int, error) { - if !c.ctl.canWrite.Load() { - return 0, errors.New("cannot write") - } - - n, err := c.writeLimiter.write(c.Conn, b) - if err != nil { - return n, err - } - - for _, f := range c.ctl.onWrite { - f(b[:n]) - } - - return n, err -} - -// Dialer performs network dialing, but only if the controller allows it. -type Dialer struct { - ctl *NetCtl - - netDialer *net.Dialer - tlsDialer *tls.Dialer - tlsConfig *tls.Config - - readLimiter *readLimiter - writeLimiter *writeLimiter - - dialCount atomicUint64 -} - -// NewDialer returns a new dialer using the given net controller. -// It optionally uses a provided tls config. -func NewDialer(ctl *NetCtl, tlsConfig *tls.Config) *Dialer { - return &Dialer{ - ctl: ctl, - - netDialer: &net.Dialer{}, - tlsDialer: &tls.Dialer{Config: tlsConfig}, - tlsConfig: tlsConfig, - - readLimiter: newReadLimiter(ctl), - writeLimiter: newWriteLimiter(ctl), - - dialCount: atomicUint64{0}, - } -} - -// DialContext dials a network connection, but only if the controller allows it. -func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - return d.dialWithDialer(ctx, network, addr, d.netDialer) -} - -// DialTLSContext dials a TLS network connection, but only if the controller allows it. -func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { - return d.dialWithDialer(ctx, network, addr, d.tlsDialer) -} - -// dialWithDialer dials a network connection using the given dialer, but only if the controller allows it. -func (d *Dialer) dialWithDialer(ctx context.Context, network, addr string, dialer dialer) (net.Conn, error) { - if !d.ctl.canDial.Load() { - return nil, errors.New("cannot dial") - } - - if limit := d.ctl.dialLimit.Load(); limit > 0 && d.dialCount.Load() >= limit { - return nil, errors.New("dial limit reached") - } else { - d.dialCount.Add(1) - } - - conn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - - d.ctl.lock.Lock() - defer d.ctl.lock.Unlock() - - for _, f := range d.ctl.onDial { - f(conn) - } - - return &Conn{ - Conn: conn, - ctl: d.ctl, - - readLimiter: d.readLimiter, - writeLimiter: d.writeLimiter, - }, nil -} - -// GetRoundTripper returns a new http.RoundTripper that uses the dialer. -func (d *Dialer) GetRoundTripper() http.RoundTripper { - return &http.Transport{ - DialContext: d.DialContext, - DialTLSContext: d.DialTLSContext, - TLSClientConfig: d.tlsConfig, - } -} - -type dialer interface { - DialContext(ctx context.Context, network, addr string) (net.Conn, error) -} - -type readLimiter struct { - ctl *NetCtl - - count atomicUint64 -} - -// newReadLimiter returns a new io.Reader that reads from r, but only up to limit bytes. -func newReadLimiter(ctl *NetCtl) *readLimiter { - return &readLimiter{ - ctl: ctl, - } -} - -func (limiter *readLimiter) read(r io.Reader, b []byte) (int, error) { - if limit := limiter.ctl.readLimit.Load(); limit > 0 && limiter.count.Load() >= limit { - return 0, fmt.Errorf("refusing to read: read limit reached") - } - - n, err := r.Read(b) - if err != nil { - return n, err - } - - if limit := limiter.ctl.readLimit.Load(); limit > 0 { - if new := limiter.count.Add(uint64(n)); new >= limit { - return 0, fmt.Errorf("read failed: read limit reached") - } - } - - return n, err -} - -type writeLimiter struct { - ctl *NetCtl - - count atomicUint64 -} - -// newWriteLimiter returns a new io.Writer that writes to w, but only up to limit bytes. -func newWriteLimiter(ctl *NetCtl) *writeLimiter { - return &writeLimiter{ - ctl: ctl, - } -} - -func (limiter *writeLimiter) write(w io.Writer, b []byte) (int, error) { - if limit := limiter.ctl.writeLimit.Load(); limit > 0 && limiter.count.Load() >= limit { - return 0, fmt.Errorf("refusing to write: write limit reached") - } - - n, err := w.Write(b) - if err != nil { - return n, err - } - - if limit := limiter.ctl.writeLimit.Load(); limit > 0 { - if new := limiter.count.Add(uint64(n)); new >= limit { - return 0, fmt.Errorf("write failed: write limit reached") - } - } - - return n, err -} diff --git a/event_test.go b/event_test.go index 3bf6e14..087f798 100644 --- a/event_test.go +++ b/event_test.go @@ -22,13 +22,13 @@ func TestEventStreamer(t *testing.T) { proton.WithTransport(proton.InsecureTransport()), ) - _, _, err := s.CreateUser("user", []byte("password")) + _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) - c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password")) + c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) - createTestMessages(t, c, "password", 10) + createTestMessages(t, c, "pass", 10) latestEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) @@ -53,7 +53,7 @@ func TestEventStreamer(t *testing.T) { c.Close() // Create a new client and perform some actions with it to generate more events. - cc, _, err := m.NewClientWithLogin(ctx, "username", []byte("password")) + cc, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer cc.Close() diff --git a/manager_status_test.go b/manager_status_test.go index 9184ef2..ac22a78 100644 --- a/manager_status_test.go +++ b/manager_status_test.go @@ -15,11 +15,11 @@ func TestStatus(t *testing.T) { s := server.New() defer s.Close() - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) var ( @@ -39,7 +39,7 @@ func TestStatus(t *testing.T) { require.Zero(t, called) // Now we simulate a network failure. - netCtl.Disable() + ctl.Disable() // This should fail. require.Error(t, m.Ping(context.Background())) @@ -49,7 +49,7 @@ func TestStatus(t *testing.T) { require.Equal(t, proton.StatusDown, status) // Now we simulate a network restoration. - netCtl.Enable() + ctl.Enable() // This should succeed. require.NoError(t, m.Ping(context.Background())) @@ -63,11 +63,11 @@ func TestStatus_NoDial(t *testing.T) { s := server.New() defer s.Close() - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) var ( @@ -81,7 +81,7 @@ func TestStatus_NoDial(t *testing.T) { }) // Disable dialing. - netCtl.SetCanDial(false) + ctl.SetCanDial(false) // This should fail. require.Error(t, m.Ping(context.Background())) @@ -95,11 +95,11 @@ func TestStatus_NoRead(t *testing.T) { s := server.New() defer s.Close() - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) var ( @@ -113,7 +113,7 @@ func TestStatus_NoRead(t *testing.T) { }) // Disable reading. - netCtl.SetCanRead(false) + ctl.SetCanRead(false) // This should fail. require.Error(t, m.Ping(context.Background())) @@ -127,11 +127,11 @@ func TestStatus_NoWrite(t *testing.T) { s := server.New() defer s.Close() - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) var ( @@ -145,7 +145,7 @@ func TestStatus_NoWrite(t *testing.T) { }) // Disable writing. - netCtl.SetCanWrite(false) + ctl.SetCanWrite(false) // This should fail. require.Error(t, m.Ping(context.Background())) @@ -162,17 +162,17 @@ func TestStatus_NoReadExistingConn(t *testing.T) { _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() var dialed int - netCtl.OnDial(func(net.Conn) { + ctl.OnDial(func(net.Conn) { dialed++ }) m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) // This should succeed. @@ -184,13 +184,10 @@ func TestStatus_NoReadExistingConn(t *testing.T) { require.Equal(t, 1, dialed) // Disable reading on the existing connection. - netCtl.SetCanRead(false) + ctl.SetCanRead(false) // This should fail because we won't be able to read the response. require.Error(t, getErr(c.GetUser(context.Background()))) - - // We should still have dialed once; the connection should have been reused. - require.Equal(t, 1, dialed) } func TestStatus_NoWriteExistingConn(t *testing.T) { @@ -200,17 +197,17 @@ func TestStatus_NoWriteExistingConn(t *testing.T) { _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() var dialed int - netCtl.OnDial(func(net.Conn) { + ctl.OnDial(func(net.Conn) { dialed++ }) m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), proton.WithRetryCount(0), ) @@ -223,7 +220,7 @@ func TestStatus_NoWriteExistingConn(t *testing.T) { require.Equal(t, 1, dialed) // Disable reading on the existing connection. - netCtl.SetCanWrite(false) + ctl.SetCanWrite(false) // This should fail because we won't be able to write the request. require.Error(t, c.LabelMessages(context.Background(), []string{"messageID"}, proton.TrashLabel)) @@ -240,7 +237,7 @@ func TestStatus_ContextCancel(t *testing.T) { var called int - m.AddStatusObserver(func(val proton.Status) { + m.AddStatusObserver(func(proton.Status) { called++ }) @@ -263,7 +260,7 @@ func TestStatus_ContextTimeout(t *testing.T) { var called int - m.AddStatusObserver(func(val proton.Status) { + m.AddStatusObserver(func(proton.Status) { called++ }) diff --git a/manager_test.go b/manager_test.go index f599d51..269b8cd 100644 --- a/manager_test.go +++ b/manager_test.go @@ -4,16 +4,15 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net" "net/http" "net/http/httptest" + "strconv" "testing" "time" "github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api/server" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,17 +20,17 @@ func TestConnectionReuse(t *testing.T) { s := server.New() defer s.Close() - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() var dialed int - netCtl.OnDial(func(net.Conn) { + ctl.OnDial(func(net.Conn) { dialed++ }) m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) // This should succeed; the resulting connection should be reused. @@ -69,89 +68,70 @@ func TestAuthRefresh(t *testing.T) { } func TestHandleTooManyRequests(t *testing.T) { - var numCalls int + // Create a server with a rate limit of 1 request per second. + s := server.New(server.WithRateLimit(1, time.Second)) + defer s.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - numCalls++ + var calls []server.Call - if numCalls < 5 { - w.WriteHeader(http.StatusTooManyRequests) - } else { - w.WriteHeader(http.StatusOK) - } - })) - defer ts.Close() + // Watch the calls made. + s.AddCallWatcher(func(call server.Call) { + calls = append(calls, call) + }) m := proton.New( - proton.WithHostURL(ts.URL), - proton.WithRetryCount(5), + proton.WithHostURL(s.GetHostURL()), + proton.WithTransport(proton.InsecureTransport()), ) + defer m.Close() - // The call should succeed because the 5th retry should succeed (429s are retried). - c := m.NewClient("", "", "") - defer c.Close() - - if _, err := c.GetAddresses(context.Background()); err != nil { - t.Fatal("got unexpected error", err) + // Make five calls; they should all succeed, but will be rate limited. + for i := 0; i < 5; i++ { + require.NoError(t, m.Ping(context.Background())) } - // The server should be called 5 times. - // The first four calls should return 429 and the last call should return 200. - if numCalls != 5 { - t.Fatal("expected numCalls to be 5, instead got", numCalls) + // After each 429 response, we should wait at least the requested duration before making the next request. + for idx, call := range calls { + if call.Status == http.StatusTooManyRequests { + after, err := strconv.Atoi(call.ResponseHeader.Get("Retry-After")) + require.NoError(t, err) + + // The next call should be made after the requested duration. + require.True(t, calls[idx+1].Time.After(call.Time.Add(time.Duration(after)*time.Second))) + } } } -func TestHandleTooManyRequestsRetryAfter(t *testing.T) { - getDelay := func(iCal int) time.Duration { - return time.Duration(5*1<= 0 { - delay := getDelay(iRetry) - assert.False(t, currentCall.Before( - lastCall.Add(delay)), - "Delay was %v but expected to have %v", - currentCall.Sub(lastCall), - delay, - ) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if len(calls) == 0 { + w.Header().Set("Retry-After", "malformed") + w.WriteHeader(http.StatusTooManyRequests) } - iRetry++ - lastCall = currentCall - - // test defaul 10sec - if iRetry == 1 { - w.Header().Set("Retry-After", "something") - } else { - w.Header().Set("Retry-After", fmt.Sprintf("%.0f", getDelay(iRetry).Seconds())) - } - w.WriteHeader(http.StatusTooManyRequests) + calls = append(calls, time.Now()) })) defer ts.Close() - m := proton.New( - proton.WithHostURL(ts.URL), - proton.WithRetryCount(3), - ) + m := proton.New(proton.WithHostURL(ts.URL)) + defer m.Close() - c := m.NewClient("", "", "") - defer c.Close() + require.NoError(t, m.Ping(context.Background())) - _, err := c.GetAddresses(context.Background()) - require.Error(t, err) + // The first call should fail because the Retry-After header is invalid. + // The second call should succeed. + require.Len(t, calls, 2) + + // The second call should be made at least 10 seconds after the first call. + require.True(t, calls[1].After(calls[0].Add(10*time.Second))) } func TestHandleUnprocessableEntity(t *testing.T) { var numCalls int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { numCalls++ w.WriteHeader(http.StatusUnprocessableEntity) })) @@ -180,7 +160,7 @@ func TestHandleUnprocessableEntity(t *testing.T) { func TestHandleDialFailure(t *testing.T) { var numCalls int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { numCalls++ w.WriteHeader(http.StatusOK) })) @@ -210,7 +190,7 @@ func TestHandleDialFailure(t *testing.T) { func TestHandleTooManyDialFailures(t *testing.T) { var numCalls int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { numCalls++ w.WriteHeader(http.StatusOK) })) @@ -242,7 +222,7 @@ func TestHandleTooManyDialFailures(t *testing.T) { func TestRetriesWithContextTimeout(t *testing.T) { var numCalls int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { numCalls++ if numCalls < 5 { @@ -276,7 +256,7 @@ func TestRetriesWithContextTimeout(t *testing.T) { } func TestReturnErrNoConnection(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) defer ts.Close() @@ -305,7 +285,7 @@ func TestStatusCallbacks(t *testing.T) { m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) statusCh := make(chan proton.Status, 1) diff --git a/netctl.go b/netctl.go new file mode 100644 index 0000000..ee8b3f1 --- /dev/null +++ b/netctl.go @@ -0,0 +1,398 @@ +package proton + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" +) + +// InsecureTransport returns an http.Transport with InsecureSkipVerify set to true. +func InsecureTransport() *http.Transport { + return &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } +} + +// ctl can be used to control whether a dialer can dial, and whether the resulting +// connection can read or write. +type NetCtl struct { + canDial bool + dialLimit uint64 + dialCount uint64 + onDial []func(net.Conn) + dlock sync.RWMutex + + canRead bool + readLimit uint64 + readCount uint64 + readSpeed int + onRead []func([]byte) + rlock sync.RWMutex + + canWrite bool + writeLimit uint64 + writeCount uint64 + writeSpeed int + onWrite []func([]byte) + wlock sync.RWMutex + + conns []net.Conn +} + +// NewNetCtl returns a new ctl with all fields set to true. +func NewNetCtl() *NetCtl { + return &NetCtl{ + canDial: true, + canRead: true, + canWrite: true, + } +} + +// SetCanDial sets whether the dialer can dial. +func (c *NetCtl) SetCanDial(canDial bool) { + c.dlock.Lock() + defer c.dlock.Unlock() + + c.canDial = canDial +} + +// SetDialLimit sets the maximum number of times dialers using this controller can dial. +func (c *NetCtl) SetDialLimit(limit uint64) { + c.dlock.Lock() + defer c.dlock.Unlock() + + c.dialLimit = limit +} + +// SetCanRead sets whether the connection can read. +func (c *NetCtl) SetCanRead(canRead bool) { + c.dlock.Lock() + defer c.dlock.Unlock() + + for _, conn := range c.conns { + conn.Close() + } + + c.rlock.Lock() + defer c.rlock.Unlock() + + c.canRead = canRead +} + +// SetReadLimit sets the maximum number of bytes that can be read. +func (c *NetCtl) SetReadLimit(limit uint64) { + c.dlock.Lock() + defer c.dlock.Unlock() + + for _, conn := range c.conns { + conn.Close() + } + + c.rlock.Lock() + defer c.rlock.Unlock() + + c.readLimit = limit + c.readCount = 0 +} + +// SetReadSpeed sets the maximum number of bytes that can be read per second. +func (c *NetCtl) SetReadSpeed(speed int) { + c.dlock.Lock() + defer c.dlock.Unlock() + + for _, conn := range c.conns { + conn.Close() + } + + c.rlock.Lock() + defer c.rlock.Unlock() + + c.readSpeed = speed +} + +// SetCanWrite sets whether the connection can write. +func (c *NetCtl) SetCanWrite(canWrite bool) { + c.dlock.Lock() + defer c.dlock.Unlock() + + for _, conn := range c.conns { + conn.Close() + } + + c.wlock.Lock() + defer c.wlock.Unlock() + + c.canWrite = canWrite +} + +// SetWriteLimit sets the maximum number of bytes that can be written. +func (c *NetCtl) SetWriteLimit(limit uint64) { + c.dlock.Lock() + defer c.dlock.Unlock() + + for _, conn := range c.conns { + conn.Close() + } + + c.wlock.Lock() + defer c.wlock.Unlock() + + c.writeLimit = limit + c.writeCount = 0 +} + +// SetWriteSpeed sets the maximum number of bytes that can be written per second. +func (c *NetCtl) SetWriteSpeed(speed int) { + c.dlock.Lock() + defer c.dlock.Unlock() + + for _, conn := range c.conns { + conn.Close() + } + + c.wlock.Lock() + defer c.wlock.Unlock() + + c.writeSpeed = speed +} + +// OnDial adds a callback that is called with the created connection when a dial is successful. +func (c *NetCtl) OnDial(f func(net.Conn)) { + c.dlock.Lock() + defer c.dlock.Unlock() + + c.onDial = append(c.onDial, f) +} + +// OnRead adds a callback that is called with the read bytes when a read is successful. +func (c *NetCtl) OnRead(fn func([]byte)) { + c.rlock.Lock() + defer c.rlock.Unlock() + + c.onRead = append(c.onRead, fn) +} + +// OnWrite adds a callback that is called with the written bytes when a write is successful. +func (c *NetCtl) OnWrite(fn func([]byte)) { + c.wlock.Lock() + defer c.wlock.Unlock() + + c.onWrite = append(c.onWrite, fn) +} + +// Disable is equivalent to disallowing dial, read and write. +func (c *NetCtl) Disable() { + c.SetCanDial(false) + c.SetCanRead(false) + c.SetCanWrite(false) +} + +// Enable is equivalent to allowing dial, read and write. +func (c *NetCtl) Enable() { + c.SetCanDial(true) + c.SetCanRead(true) + c.SetCanWrite(true) +} + +// NewDialer returns a new dialer controlled by the ctl. +func (c *NetCtl) NewRoundTripper(tlsConfig *tls.Config) http.RoundTripper { + return &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return c.dial(ctx, &net.Dialer{}, network, addr) + }, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return c.dial(ctx, &tls.Dialer{Config: tlsConfig}, network, addr) + }, + TLSClientConfig: tlsConfig, + } +} + +// ctxDialer implements DialContext. +type ctxDialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) +} + +// dial dials using d, but only if the controller allows it. +func (c *NetCtl) dial(ctx context.Context, dialer ctxDialer, network, addr string) (net.Conn, error) { + c.dlock.Lock() + defer c.dlock.Unlock() + + if !c.canDial { + return nil, errors.New("dial failed (not allowed)") + } + + if c.dialLimit > 0 && c.dialCount >= c.dialLimit { + return nil, errors.New("dial failed (limit reached)") + } + + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + c.dialCount++ + + for _, fn := range c.onDial { + fn(conn) + } + + c.conns = append(c.conns, conn) + + return newConn(conn, c), nil +} + +// read reads from r, but only if the controller allows it. +func (c *NetCtl) read(r io.Reader, b []byte) (int, error) { + c.rlock.Lock() + defer c.rlock.Unlock() + + if !c.canRead { + return 0, errors.New("read failed (not allowed)") + } + + if c.readLimit > 0 && c.readCount >= c.readLimit { + return 0, errors.New("read failed (limit reached)") + } + + var rem uint64 + + if c.readLimit > 0 && c.readLimit-c.readCount < uint64(len(b)) { + rem = c.readLimit - c.readCount + } else { + rem = uint64(len(b)) + } + + c.rlock.Unlock() + n, err := newSlowReader(r, c.readSpeed).Read(b[:rem]) + c.rlock.Lock() + + c.readCount += uint64(n) + + for _, fn := range c.onRead { + fn(b[:n]) + } + + return n, err +} + +// write writes to w, but only if the controller allows it. +func (c *NetCtl) write(w io.Writer, b []byte) (int, error) { + c.wlock.Lock() + defer c.wlock.Unlock() + + if !c.canWrite { + return 0, errors.New("write failed (not allowed)") + } + + if c.writeLimit > 0 && c.writeCount >= c.writeLimit { + return 0, errors.New("write failed (limit exceeded)") + } + + var rem uint64 + + if c.writeLimit > 0 && c.writeLimit-c.writeCount < uint64(len(b)) { + rem = c.writeLimit - c.writeCount + } else { + rem = uint64(len(b)) + } + + c.wlock.Unlock() + n, err := newSlowWriter(w, c.writeSpeed).Write(b[:rem]) + c.wlock.Lock() + + c.writeCount += uint64(n) + + for _, fn := range c.onWrite { + fn(b[:n]) + } + + if uint64(n) < rem { + return n, fmt.Errorf("write incomplete (limit reached)") + } + + return n, err +} + +// conn is a wrapper around net.conn that can be used to control whether a connection can read or write. +type conn struct { + net.Conn + + ctl *NetCtl +} + +func newConn(c net.Conn, ctl *NetCtl) *conn { + return &conn{ + Conn: c, + ctl: ctl, + } +} + +// Read reads from the wrapped connection, but only if the controller allows it. +func (c *conn) Read(b []byte) (int, error) { + return c.ctl.read(c.Conn, b) +} + +// Write writes to the wrapped connection, but only if the controller allows it. +func (c *conn) Write(b []byte) (int, error) { + return c.ctl.write(c.Conn, b) +} + +// slowReader is an io.Reader that reads at a fixed rate. +type slowReader struct { + r io.Reader + + // bytesPerSec is the number of bytes to read per second. + bytesPerSec int +} + +func newSlowReader(r io.Reader, bytesPerSec int) *slowReader { + return &slowReader{ + r: r, + bytesPerSec: bytesPerSec, + } +} + +func (r *slowReader) Read(b []byte) (int, error) { + start := time.Now() + + n, err := r.r.Read(b) + + if r.bytesPerSec > 0 { + time.Sleep(time.Until(start.Add(time.Duration(n*r.bytesPerSec) * time.Second))) + } + + return n, err +} + +// slowWriter is an io.Writer that writes at a fixed rate. +type slowWriter struct { + w io.Writer + + // bytesPerSec is the number of bytes to write per second. + bytesPerSec int +} + +func newSlowWriter(w io.Writer, bytesPerSec int) *slowWriter { + return &slowWriter{ + w: w, + bytesPerSec: bytesPerSec, + } +} + +func (w *slowWriter) Write(b []byte) (int, error) { + start := time.Now() + + n, err := w.w.Write(b) + + if w.bytesPerSec > 0 { + time.Sleep(time.Until(start.Add(time.Duration(n*w.bytesPerSec) * time.Second))) + } + + return n, err +} diff --git a/dialer_test.go b/netctl_test.go similarity index 81% rename from dialer_test.go rename to netctl_test.go index 6ef32c2..ee9f01f 100644 --- a/dialer_test.go +++ b/netctl_test.go @@ -14,7 +14,7 @@ import ( func TestNetCtl_ReadLimit(t *testing.T) { // Create a test http server that writes 100 bytes. // Including the header, this is 217 bytes (100 bytes + 117 bytes). - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write(make([]byte, 100)); err != nil { t.Fatal(err) } @@ -22,16 +22,16 @@ func TestNetCtl_ReadLimit(t *testing.T) { defer ts.Close() // Create a new net controller. - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() + + // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. + ctl.SetReadLimit(300) // Create a new http client with the dialer. client := &http.Client{ - Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}), } - // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. - netCtl.SetReadLimit(300) - // This should succeed. if resp, err := client.Get(ts.URL); err != nil { t.Fatal(err) @@ -47,7 +47,7 @@ func TestNetCtl_ReadLimit(t *testing.T) { func TestNetCtl_WriteLimit(t *testing.T) { // Create a test http server that reads the given body. - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { if _, err := io.ReadAll(r.Body); err != nil { t.Fatal(err) } @@ -55,16 +55,16 @@ func TestNetCtl_WriteLimit(t *testing.T) { defer ts.Close() // Create a new net controller. - netCtl := proton.NewNetCtl() + ctl := proton.NewNetCtl() + + // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. + ctl.SetWriteLimit(300) // Create a new http client with the dialer. client := &http.Client{ - Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(), + Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}), } - // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. - netCtl.SetWriteLimit(300) - // This should succeed. if resp, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err != nil { t.Fatal(err) diff --git a/response.go b/response.go index 4017fa5..d344b67 100644 --- a/response.go +++ b/response.go @@ -62,37 +62,31 @@ func updateTime(_ *resty.Client, res *resty.Response) error { return nil } +// nolint:gosec func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) { - if res.StatusCode() == http.StatusTooManyRequests { - if after := res.Header().Get("Retry-After"); after != "" { - l := logrus.WithFields(logrus.Fields{ - "pkg": "go-proton-api", - "statusCode": res.StatusCode(), - "url": res.Request.URL, - "verb": res.Request.Method, - }) - - seconds, err := strconv.Atoi(after) - if err != nil { - l.WithField("after", after).WithError(err).Warning( - "Cannot convert Retry-After to a number, continue with default 10 second cooldown.", - ) - seconds = 10 - } - - // To avoid spikes when all clients retry at the same time, we add some random wait. - seconds += rand.Intn(10) //nolint:gosec // It is OK to use weak random number generator here. - l = l.WithField("delay", seconds) - - // Maximum retry time in client is is one minute. But - // here wait times can be longer e.g. high API load - l.Warn("Delay the retry after http response") - return time.Duration(seconds) * time.Second, nil - } + // 0 and no error means default behaviour which is exponential backoff with jitter. + if res.StatusCode() != http.StatusTooManyRequests { + return 0, nil } - // 0 and no error means default behaviour which is exponential backoff with jitter. - return 0, nil + // Parse the Retry-After header, or fallback to 10 seconds. + after, err := strconv.Atoi(res.Header().Get("Retry-After")) + if err != nil { + after = 10 + } + + // Add some jitter to the delay. + after += rand.Intn(10) + + logrus.WithFields(logrus.Fields{ + "pkg": "go-proton-api", + "status": res.StatusCode(), + "url": res.Request.URL, + "method": res.Request.Method, + "after": after, + }).Warn("Too many requests, retrying after delay") + + return time.Duration(after) * time.Second, nil } func catchTooManyRequests(res *resty.Response, _ error) bool { diff --git a/server/backend/crypto_fast.go b/server/backend/crypto_fast.go new file mode 100644 index 0000000..9e020a5 --- /dev/null +++ b/server/backend/crypto_fast.go @@ -0,0 +1,25 @@ +package backend + +import "github.com/ProtonMail/gopenpgp/v2/crypto" + +var preCompKey *crypto.Key + +func init() { + key, err := crypto.GenerateKey("name", "email", "rsa", 1024) + if err != nil { + panic(err) + } + + preCompKey = key +} + +// FastGenerateKey is a fast version of GenerateKey that uses a pre-computed key. +// This is useful for testing but is incredibly insecure. +func FastGenerateKey(_, _ string, passphrase []byte, _ string, _ int) (string, error) { + encKey, err := preCompKey.Lock(passphrase) + if err != nil { + return "", err + } + + return encKey.Armor() +} diff --git a/server/call.go b/server/call.go index 9bd78bc..45eff9f 100644 --- a/server/call.go +++ b/server/call.go @@ -3,6 +3,7 @@ package server import ( "net/http" "net/url" + "time" ) type Call struct { @@ -10,6 +11,9 @@ type Call struct { Method string Status int + Time time.Time + Duration time.Duration + RequestHeader http.Header RequestBody []byte diff --git a/server/init_test.go b/server/init_test.go index 79b446c..c885b17 100644 --- a/server/init_test.go +++ b/server/init_test.go @@ -1,22 +1,7 @@ package server -import ( - "github.com/ProtonMail/go-proton-api/server/backend" - "github.com/ProtonMail/gopenpgp/v2/crypto" -) +import "github.com/ProtonMail/go-proton-api/server/backend" func init() { - key, err := crypto.GenerateKey("name", "email", "rsa", 1024) - if err != nil { - panic(err) - } - - backend.GenerateKey = func(_, _ string, passphrase []byte, _ string, _ int) (string, error) { - encKey, err := key.Lock(passphrase) - if err != nil { - return "", err - } - - return encKey.Armor() - } + backend.GenerateKey = backend.FastGenerateKey } diff --git a/server/rate_limit.go b/server/rate_limit.go new file mode 100644 index 0000000..29bb59e --- /dev/null +++ b/server/rate_limit.go @@ -0,0 +1,51 @@ +package server + +import ( + "sync" + "time" +) + +// rateLimiter is a rate limiter for the server. +// If more than limit requests are made in the time window, the server will return 429. +type rateLimiter struct { + // limit is the rate limit to apply to the server. + limit int + + // window is the window in which to apply the rate limit. + window time.Duration + + // nextReset is the time at which the rate limit will reset. + nextReset time.Time + + // count is the number of calls made to the server. + count int + + // countLock is a mutex for the callCount. + countLock sync.Mutex +} + +func newRateLimiter(limit int, window time.Duration) *rateLimiter { + return &rateLimiter{ + limit: limit, + window: window, + } +} + +// exceeded checks the rate limit and returns how long to wait before the next request. +func (r *rateLimiter) exceeded() time.Duration { + r.countLock.Lock() + defer r.countLock.Unlock() + + if time.Now().After(r.nextReset) { + r.count = 0 + r.nextReset = time.Now().Add(r.window) + } + + r.count++ + + if r.count > r.limit { + return time.Until(r.nextReset) + } + + return 0 +} diff --git a/server/router.go b/server/router.go index c83cadb..9a61ff3 100644 --- a/server/router.go +++ b/server/router.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "time" @@ -20,6 +21,7 @@ func initRouter(s *Server) { s.r.Use( s.requireValidAppVersion(), s.setSessionCookie(), + s.applyRateLimit(), ) if core := s.r.Group("/core/v4"); core != nil { @@ -171,8 +173,23 @@ func (s *Server) setSessionCookie() gin.HandlerFunc { } } +func (s *Server) applyRateLimit() gin.HandlerFunc { + return func(c *gin.Context) { + if s.rateLimit == nil { + return + } + + if wait := s.rateLimit.exceeded(); wait > 0 { + c.Header("Retry-After", strconv.Itoa(int(wait.Seconds()))) + c.AbortWithStatus(http.StatusTooManyRequests) + } + } +} + func (s *Server) logCalls() gin.HandlerFunc { return func(c *gin.Context) { + start := time.Now() + req, err := io.ReadAll(c.Request.Body) if err != nil { panic(err) @@ -199,6 +216,9 @@ func (s *Server) logCalls() gin.HandlerFunc { Method: c.Request.Method, Status: c.Writer.Status(), + Time: start, + Duration: time.Since(start), + RequestHeader: c.Request.Header, RequestBody: req, diff --git a/server/server.go b/server/server.go index b6ac468..3e797c8 100644 --- a/server/server.go +++ b/server/server.go @@ -47,6 +47,9 @@ type Server struct { // offline is whether to pretend the server is offline and return 5xx errors. offline bool + + // rateLimit is the rate limiter for the server. + rateLimit *rateLimiter } func New(opts ...Option) *Server { @@ -69,7 +72,7 @@ func (s *Server) GetProxyURL() string { return s.s.URL + "/proxy" } -// GetDomain returns the domain of the server. +// GetDomain returns the domain of the server (e.g. "proton.local"). func (s *Server) GetDomain() string { return s.domain } diff --git a/server/server_builder.go b/server/server_builder.go index 2d28a18..fda65cd 100644 --- a/server/server_builder.go +++ b/server/server_builder.go @@ -12,11 +12,12 @@ import ( ) type serverBuilder struct { - withTLS bool - domain string - logger io.Writer - origin string - cacher AuthCacher + withTLS bool + domain string + logger io.Writer + origin string + cacher AuthCacher + rateLimiter *rateLimiter } func newServerBuilder() *serverBuilder { @@ -46,6 +47,7 @@ func (builder *serverBuilder) build() *Server { domain: builder.domain, proxyOrigin: builder.origin, authCacher: builder.cacher, + rateLimit: builder.rateLimiter, } if builder.withTLS { @@ -143,3 +145,19 @@ type withAuthCache struct { func (opt withAuthCache) config(builder *serverBuilder) { builder.cacher = opt.cacher } + +func WithRateLimit(limit int, window time.Duration) Option { + return &withRateLimit{ + limit: limit, + window: window, + } +} + +type withRateLimit struct { + limit int + window time.Duration +} + +func (opt withRateLimit) config(builder *serverBuilder) { + builder.rateLimiter = newRateLimiter(opt.limit, opt.window) +} diff --git a/server/server_test.go b/server/server_test.go index 237eca3..0f2350c 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -77,7 +77,7 @@ func TestServer_Ping(t *testing.T) { m := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) var status proton.Status @@ -1313,7 +1313,7 @@ func TestServer_Messages_Fetch(t *testing.T) { mm := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) defer mm.Close() @@ -1355,7 +1355,7 @@ func TestServer_Messages_Status(t *testing.T) { mm := proton.New( proton.WithHostURL(s.GetHostURL()), - proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()), + proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) defer mm.Close() diff --git a/unlock.go b/unlock.go index 731bb87..faa6842 100644 --- a/unlock.go +++ b/unlock.go @@ -12,6 +12,8 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi userKR, err := user.Keys.Unlock(saltedKeyPass, nil) if err != nil { return nil, nil, fmt.Errorf("failed to unlock user keys: %w", err) + } else if userKR.CountDecryptionEntities() == 0 { + return nil, nil, fmt.Errorf("failed to unlock any user keys") } addrKRs := make(map[string]*crypto.KeyRing) @@ -19,9 +21,13 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi for idx, addrKR := range parallel.Map(runtime.NumCPU(), addresses, func(addr Address) *crypto.KeyRing { return addr.Keys.TryUnlock(saltedKeyPass, userKR) }) { - if addrKR != nil { - addrKRs[addresses[idx].ID] = addrKR + if addrKR.CountDecryptionEntities() == 0 { + continue + } else if addrKR == nil { + continue } + + addrKRs[addresses[idx].ID] = addrKR } if len(addrKRs) == 0 {