refactor: NetCtl, transport, dialer, rate limiting

This commit is contained in:
James Houlahan
2022-12-12 19:23:24 +01:00
committed by James
parent 779a2ee672
commit fd06b106da
17 changed files with 662 additions and 492 deletions

View File

@@ -16,7 +16,7 @@ func TestAuth(t *testing.T) {
s := server.New()
defer s.Close()
_, _, err := s.CreateUser("user", []byte("password"))
_, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
m := proton.New(
@@ -26,14 +26,14 @@ func TestAuth(t *testing.T) {
defer m.Close()
// Create one session.
c1, auth1, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
c1, auth1, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
// Revoke all other sessions.
require.NoError(t, c1.AuthRevokeAll(context.Background()))
// Create another session.
c2, _, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
c2, _, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
// There should be two sessions.
@@ -61,7 +61,7 @@ func TestAuth_Refresh(t *testing.T) {
defer s.Close()
// Create a user on the server.
userID, _, err := s.CreateUser("user", []byte("password"))
userID, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
// The auth is valid for 4 seconds.
@@ -74,7 +74,7 @@ func TestAuth_Refresh(t *testing.T) {
defer m.Close()
// Create one session for the user.
c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
require.Equal(t, userID, auth.UserID)
@@ -85,7 +85,7 @@ func TestAuth_Refresh(t *testing.T) {
{
user, err := c.GetUser(context.Background())
require.NoError(t, err)
require.Equal(t, "username", user.Name)
require.Equal(t, "user", user.Name)
require.Equal(t, userID, user.ID)
}
@@ -96,7 +96,7 @@ func TestAuth_Refresh(t *testing.T) {
{
user, err := c.GetUser(context.Background())
require.NoError(t, err)
require.Equal(t, "username", user.Name)
require.Equal(t, "user", user.Name)
require.Equal(t, userID, user.ID)
}
}
@@ -106,7 +106,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
defer s.Close()
// Create a user on the server.
userID, _, err := s.CreateUser("user", []byte("password"))
userID, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
// The auth is valid for 4 seconds.
@@ -118,7 +118,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
)
defer m.Close()
c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
require.Equal(t, userID, auth.UserID)
@@ -128,7 +128,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
parallel.Do(runtime.NumCPU(), 100, func(idx int) {
user, err := c.GetUser(context.Background())
require.NoError(t, err)
require.Equal(t, "username", user.Name)
require.Equal(t, "user", user.Name)
require.Equal(t, userID, user.ID)
})
@@ -139,7 +139,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
parallel.Do(runtime.NumCPU(), 100, func(idx int) {
user, err := c.GetUser(context.Background())
require.NoError(t, err)
require.Equal(t, "username", user.Name)
require.Equal(t, "user", user.Name)
require.Equal(t, userID, user.ID)
})
}
@@ -149,7 +149,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
defer s.Close()
// Create a user on the server.
userID, _, err := s.CreateUser("user", []byte("password"))
userID, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
m := proton.New(
@@ -159,7 +159,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
defer m.Close()
// Create one session for the user.
c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
require.NoError(t, err)
require.Equal(t, userID, auth.UserID)
@@ -172,7 +172,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
{
user, err := c.GetUser(context.Background())
require.NoError(t, err)
require.Equal(t, "username", user.Name)
require.Equal(t, "user", user.Name)
require.Equal(t, userID, user.ID)
}

311
dialer.go
View File

@@ -1,311 +0,0 @@
package proton
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
)
func InsecureTransport() *http.Transport {
return &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
}
// NetCtl can be used to control whether a dialer can dial, and whether the resulting
// connection can read or write.
type NetCtl struct {
canDial atomicBool
dialLimit atomicUint64
canRead atomicBool
readLimit atomicUint64
canWrite atomicBool
writeLimit atomicUint64
onDial []func(net.Conn)
onRead []func([]byte)
onWrite []func([]byte)
lock sync.Mutex
}
// NewNetCtl returns a new NetCtl with all fields set to true.
func NewNetCtl() *NetCtl {
return &NetCtl{
canDial: atomicBool{b32(true)},
canRead: atomicBool{b32(true)},
canWrite: atomicBool{b32(true)},
}
}
// SetCanDial sets whether the dialer can dial.
func (c *NetCtl) SetCanDial(canDial bool) {
c.canDial.Store(canDial)
}
// SetDialLimit sets the maximum number of times dialers using this controller can dial.
func (c *NetCtl) SetDialLimit(limit uint64) {
c.dialLimit.Store(limit)
}
// SetCanRead sets whether the connection can read.
func (c *NetCtl) SetCanRead(canRead bool) {
c.canRead.Store(canRead)
}
// SetReadLimit sets the maximum number of bytes that can be read.
func (c *NetCtl) SetReadLimit(limit uint64) {
c.readLimit.Store(limit)
}
// SetCanWrite sets whether the connection can write.
func (c *NetCtl) SetCanWrite(canWrite bool) {
c.canWrite.Store(canWrite)
}
// SetWriteLimit sets the maximum number of bytes that can be written.
func (c *NetCtl) SetWriteLimit(limit uint64) {
c.writeLimit.Store(limit)
}
// OnDial adds a callback that is called with the created connection when a dial is successful.
func (c *NetCtl) OnDial(f func(net.Conn)) {
c.lock.Lock()
defer c.lock.Unlock()
c.onDial = append(c.onDial, f)
}
// OnRead adds a callback that is called with the read bytes when a read is successful.
func (c *NetCtl) OnRead(f func([]byte)) {
c.lock.Lock()
defer c.lock.Unlock()
c.onRead = append(c.onRead, f)
}
// OnWrite adds a callback that is called with the written bytes when a write is successful.
func (c *NetCtl) OnWrite(f func([]byte)) {
c.lock.Lock()
defer c.lock.Unlock()
c.onWrite = append(c.onWrite, f)
}
// Disable is equivalent to disallowing dial, read and write.
func (c *NetCtl) Disable() {
c.SetCanDial(false)
c.SetCanRead(false)
c.SetCanWrite(false)
}
// Enable is equivalent to allowing dial, read and write.
func (c *NetCtl) Enable() {
c.SetCanDial(true)
c.SetCanRead(true)
c.SetCanWrite(true)
}
// Conn is a wrapper around net.Conn that can be used to control whether a connection can read or write.
type Conn struct {
net.Conn
ctl *NetCtl
readLimiter *readLimiter
writeLimiter *writeLimiter
}
// Read reads from the wrapped connection, but only if the controller allows it.
func (c *Conn) Read(b []byte) (int, error) {
if !c.ctl.canRead.Load() {
return 0, errors.New("cannot read")
}
n, err := c.readLimiter.read(c.Conn, b)
if err != nil {
return n, err
}
for _, f := range c.ctl.onRead {
f(b[:n])
}
return n, err
}
// Write writes to the wrapped connection, but only if the controller allows it.
func (c *Conn) Write(b []byte) (int, error) {
if !c.ctl.canWrite.Load() {
return 0, errors.New("cannot write")
}
n, err := c.writeLimiter.write(c.Conn, b)
if err != nil {
return n, err
}
for _, f := range c.ctl.onWrite {
f(b[:n])
}
return n, err
}
// Dialer performs network dialing, but only if the controller allows it.
type Dialer struct {
ctl *NetCtl
netDialer *net.Dialer
tlsDialer *tls.Dialer
tlsConfig *tls.Config
readLimiter *readLimiter
writeLimiter *writeLimiter
dialCount atomicUint64
}
// NewDialer returns a new dialer using the given net controller.
// It optionally uses a provided tls config.
func NewDialer(ctl *NetCtl, tlsConfig *tls.Config) *Dialer {
return &Dialer{
ctl: ctl,
netDialer: &net.Dialer{},
tlsDialer: &tls.Dialer{Config: tlsConfig},
tlsConfig: tlsConfig,
readLimiter: newReadLimiter(ctl),
writeLimiter: newWriteLimiter(ctl),
dialCount: atomicUint64{0},
}
}
// DialContext dials a network connection, but only if the controller allows it.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dialWithDialer(ctx, network, addr, d.netDialer)
}
// DialTLSContext dials a TLS network connection, but only if the controller allows it.
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dialWithDialer(ctx, network, addr, d.tlsDialer)
}
// dialWithDialer dials a network connection using the given dialer, but only if the controller allows it.
func (d *Dialer) dialWithDialer(ctx context.Context, network, addr string, dialer dialer) (net.Conn, error) {
if !d.ctl.canDial.Load() {
return nil, errors.New("cannot dial")
}
if limit := d.ctl.dialLimit.Load(); limit > 0 && d.dialCount.Load() >= limit {
return nil, errors.New("dial limit reached")
} else {
d.dialCount.Add(1)
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
d.ctl.lock.Lock()
defer d.ctl.lock.Unlock()
for _, f := range d.ctl.onDial {
f(conn)
}
return &Conn{
Conn: conn,
ctl: d.ctl,
readLimiter: d.readLimiter,
writeLimiter: d.writeLimiter,
}, nil
}
// GetRoundTripper returns a new http.RoundTripper that uses the dialer.
func (d *Dialer) GetRoundTripper() http.RoundTripper {
return &http.Transport{
DialContext: d.DialContext,
DialTLSContext: d.DialTLSContext,
TLSClientConfig: d.tlsConfig,
}
}
type dialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
type readLimiter struct {
ctl *NetCtl
count atomicUint64
}
// newReadLimiter returns a new io.Reader that reads from r, but only up to limit bytes.
func newReadLimiter(ctl *NetCtl) *readLimiter {
return &readLimiter{
ctl: ctl,
}
}
func (limiter *readLimiter) read(r io.Reader, b []byte) (int, error) {
if limit := limiter.ctl.readLimit.Load(); limit > 0 && limiter.count.Load() >= limit {
return 0, fmt.Errorf("refusing to read: read limit reached")
}
n, err := r.Read(b)
if err != nil {
return n, err
}
if limit := limiter.ctl.readLimit.Load(); limit > 0 {
if new := limiter.count.Add(uint64(n)); new >= limit {
return 0, fmt.Errorf("read failed: read limit reached")
}
}
return n, err
}
type writeLimiter struct {
ctl *NetCtl
count atomicUint64
}
// newWriteLimiter returns a new io.Writer that writes to w, but only up to limit bytes.
func newWriteLimiter(ctl *NetCtl) *writeLimiter {
return &writeLimiter{
ctl: ctl,
}
}
func (limiter *writeLimiter) write(w io.Writer, b []byte) (int, error) {
if limit := limiter.ctl.writeLimit.Load(); limit > 0 && limiter.count.Load() >= limit {
return 0, fmt.Errorf("refusing to write: write limit reached")
}
n, err := w.Write(b)
if err != nil {
return n, err
}
if limit := limiter.ctl.writeLimit.Load(); limit > 0 {
if new := limiter.count.Add(uint64(n)); new >= limit {
return 0, fmt.Errorf("write failed: write limit reached")
}
}
return n, err
}

View File

@@ -22,13 +22,13 @@ func TestEventStreamer(t *testing.T) {
proton.WithTransport(proton.InsecureTransport()),
)
_, _, err := s.CreateUser("user", []byte("password"))
_, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))
c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
require.NoError(t, err)
createTestMessages(t, c, "password", 10)
createTestMessages(t, c, "pass", 10)
latestEventID, err := c.GetLatestEventID(ctx)
require.NoError(t, err)
@@ -53,7 +53,7 @@ func TestEventStreamer(t *testing.T) {
c.Close()
// Create a new client and perform some actions with it to generate more events.
cc, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))
cc, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
require.NoError(t, err)
defer cc.Close()

View File

@@ -15,11 +15,11 @@ func TestStatus(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
var (
@@ -39,7 +39,7 @@ func TestStatus(t *testing.T) {
require.Zero(t, called)
// Now we simulate a network failure.
netCtl.Disable()
ctl.Disable()
// This should fail.
require.Error(t, m.Ping(context.Background()))
@@ -49,7 +49,7 @@ func TestStatus(t *testing.T) {
require.Equal(t, proton.StatusDown, status)
// Now we simulate a network restoration.
netCtl.Enable()
ctl.Enable()
// This should succeed.
require.NoError(t, m.Ping(context.Background()))
@@ -63,11 +63,11 @@ func TestStatus_NoDial(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
var (
@@ -81,7 +81,7 @@ func TestStatus_NoDial(t *testing.T) {
})
// Disable dialing.
netCtl.SetCanDial(false)
ctl.SetCanDial(false)
// This should fail.
require.Error(t, m.Ping(context.Background()))
@@ -95,11 +95,11 @@ func TestStatus_NoRead(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
var (
@@ -113,7 +113,7 @@ func TestStatus_NoRead(t *testing.T) {
})
// Disable reading.
netCtl.SetCanRead(false)
ctl.SetCanRead(false)
// This should fail.
require.Error(t, m.Ping(context.Background()))
@@ -127,11 +127,11 @@ func TestStatus_NoWrite(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
var (
@@ -145,7 +145,7 @@ func TestStatus_NoWrite(t *testing.T) {
})
// Disable writing.
netCtl.SetCanWrite(false)
ctl.SetCanWrite(false)
// This should fail.
require.Error(t, m.Ping(context.Background()))
@@ -162,17 +162,17 @@ func TestStatus_NoReadExistingConn(t *testing.T) {
_, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
var dialed int
netCtl.OnDial(func(net.Conn) {
ctl.OnDial(func(net.Conn) {
dialed++
})
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
// This should succeed.
@@ -184,13 +184,10 @@ func TestStatus_NoReadExistingConn(t *testing.T) {
require.Equal(t, 1, dialed)
// Disable reading on the existing connection.
netCtl.SetCanRead(false)
ctl.SetCanRead(false)
// This should fail because we won't be able to read the response.
require.Error(t, getErr(c.GetUser(context.Background())))
// We should still have dialed once; the connection should have been reused.
require.Equal(t, 1, dialed)
}
func TestStatus_NoWriteExistingConn(t *testing.T) {
@@ -200,17 +197,17 @@ func TestStatus_NoWriteExistingConn(t *testing.T) {
_, _, err := s.CreateUser("user", []byte("pass"))
require.NoError(t, err)
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
var dialed int
netCtl.OnDial(func(net.Conn) {
ctl.OnDial(func(net.Conn) {
dialed++
})
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
proton.WithRetryCount(0),
)
@@ -223,7 +220,7 @@ func TestStatus_NoWriteExistingConn(t *testing.T) {
require.Equal(t, 1, dialed)
// Disable reading on the existing connection.
netCtl.SetCanWrite(false)
ctl.SetCanWrite(false)
// This should fail because we won't be able to write the request.
require.Error(t, c.LabelMessages(context.Background(), []string{"messageID"}, proton.TrashLabel))
@@ -240,7 +237,7 @@ func TestStatus_ContextCancel(t *testing.T) {
var called int
m.AddStatusObserver(func(val proton.Status) {
m.AddStatusObserver(func(proton.Status) {
called++
})
@@ -263,7 +260,7 @@ func TestStatus_ContextTimeout(t *testing.T) {
var called int
m.AddStatusObserver(func(val proton.Status) {
m.AddStatusObserver(func(proton.Status) {
called++
})

View File

@@ -4,16 +4,15 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/go-proton-api/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -21,17 +20,17 @@ func TestConnectionReuse(t *testing.T) {
s := server.New()
defer s.Close()
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
var dialed int
netCtl.OnDial(func(net.Conn) {
ctl.OnDial(func(net.Conn) {
dialed++
})
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
// This should succeed; the resulting connection should be reused.
@@ -69,89 +68,70 @@ func TestAuthRefresh(t *testing.T) {
}
func TestHandleTooManyRequests(t *testing.T) {
var numCalls int
// Create a server with a rate limit of 1 request per second.
s := server.New(server.WithRateLimit(1, time.Second))
defer s.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
numCalls++
var calls []server.Call
if numCalls < 5 {
w.WriteHeader(http.StatusTooManyRequests)
} else {
w.WriteHeader(http.StatusOK)
}
}))
defer ts.Close()
// Watch the calls made.
s.AddCallWatcher(func(call server.Call) {
calls = append(calls, call)
})
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(5),
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.InsecureTransport()),
)
defer m.Close()
// The call should succeed because the 5th retry should succeed (429s are retried).
c := m.NewClient("", "", "")
defer c.Close()
if _, err := c.GetAddresses(context.Background()); err != nil {
t.Fatal("got unexpected error", err)
// Make five calls; they should all succeed, but will be rate limited.
for i := 0; i < 5; i++ {
require.NoError(t, m.Ping(context.Background()))
}
// The server should be called 5 times.
// The first four calls should return 429 and the last call should return 200.
if numCalls != 5 {
t.Fatal("expected numCalls to be 5, instead got", numCalls)
// After each 429 response, we should wait at least the requested duration before making the next request.
for idx, call := range calls {
if call.Status == http.StatusTooManyRequests {
after, err := strconv.Atoi(call.ResponseHeader.Get("Retry-After"))
require.NoError(t, err)
// The next call should be made after the requested duration.
require.True(t, calls[idx+1].Time.After(call.Time.Add(time.Duration(after)*time.Second)))
}
}
}
func TestHandleTooManyRequestsRetryAfter(t *testing.T) {
getDelay := func(iCal int) time.Duration {
return time.Duration(5*1<<iCal) * time.Second
}
func TestHandleTooManyRequests_Malformed(t *testing.T) {
var calls []time.Time
iRetry := -1
lastCall := time.Now()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
currentCall := time.Now()
if iRetry >= 0 {
delay := getDelay(iRetry)
assert.False(t, currentCall.Before(
lastCall.Add(delay)),
"Delay was %v but expected to have %v",
currentCall.Sub(lastCall),
delay,
)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
if len(calls) == 0 {
w.Header().Set("Retry-After", "malformed")
w.WriteHeader(http.StatusTooManyRequests)
}
iRetry++
lastCall = currentCall
// test defaul 10sec
if iRetry == 1 {
w.Header().Set("Retry-After", "something")
} else {
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", getDelay(iRetry).Seconds()))
}
w.WriteHeader(http.StatusTooManyRequests)
calls = append(calls, time.Now())
}))
defer ts.Close()
m := proton.New(
proton.WithHostURL(ts.URL),
proton.WithRetryCount(3),
)
m := proton.New(proton.WithHostURL(ts.URL))
defer m.Close()
c := m.NewClient("", "", "")
defer c.Close()
require.NoError(t, m.Ping(context.Background()))
_, err := c.GetAddresses(context.Background())
require.Error(t, err)
// The first call should fail because the Retry-After header is invalid.
// The second call should succeed.
require.Len(t, calls, 2)
// The second call should be made at least 10 seconds after the first call.
require.True(t, calls[1].After(calls[0].Add(10*time.Second)))
}
func TestHandleUnprocessableEntity(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
numCalls++
w.WriteHeader(http.StatusUnprocessableEntity)
}))
@@ -180,7 +160,7 @@ func TestHandleUnprocessableEntity(t *testing.T) {
func TestHandleDialFailure(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
numCalls++
w.WriteHeader(http.StatusOK)
}))
@@ -210,7 +190,7 @@ func TestHandleDialFailure(t *testing.T) {
func TestHandleTooManyDialFailures(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
numCalls++
w.WriteHeader(http.StatusOK)
}))
@@ -242,7 +222,7 @@ func TestHandleTooManyDialFailures(t *testing.T) {
func TestRetriesWithContextTimeout(t *testing.T) {
var numCalls int
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
numCalls++
if numCalls < 5 {
@@ -276,7 +256,7 @@ func TestRetriesWithContextTimeout(t *testing.T) {
}
func TestReturnErrNoConnection(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
@@ -305,7 +285,7 @@ func TestStatusCallbacks(t *testing.T) {
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
statusCh := make(chan proton.Status, 1)

398
netctl.go Normal file
View File

@@ -0,0 +1,398 @@
package proton
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
"time"
)
// InsecureTransport returns an http.Transport with InsecureSkipVerify set to true.
func InsecureTransport() *http.Transport {
return &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
}
// ctl can be used to control whether a dialer can dial, and whether the resulting
// connection can read or write.
type NetCtl struct {
canDial bool
dialLimit uint64
dialCount uint64
onDial []func(net.Conn)
dlock sync.RWMutex
canRead bool
readLimit uint64
readCount uint64
readSpeed int
onRead []func([]byte)
rlock sync.RWMutex
canWrite bool
writeLimit uint64
writeCount uint64
writeSpeed int
onWrite []func([]byte)
wlock sync.RWMutex
conns []net.Conn
}
// NewNetCtl returns a new ctl with all fields set to true.
func NewNetCtl() *NetCtl {
return &NetCtl{
canDial: true,
canRead: true,
canWrite: true,
}
}
// SetCanDial sets whether the dialer can dial.
func (c *NetCtl) SetCanDial(canDial bool) {
c.dlock.Lock()
defer c.dlock.Unlock()
c.canDial = canDial
}
// SetDialLimit sets the maximum number of times dialers using this controller can dial.
func (c *NetCtl) SetDialLimit(limit uint64) {
c.dlock.Lock()
defer c.dlock.Unlock()
c.dialLimit = limit
}
// SetCanRead sets whether the connection can read.
func (c *NetCtl) SetCanRead(canRead bool) {
c.dlock.Lock()
defer c.dlock.Unlock()
for _, conn := range c.conns {
conn.Close()
}
c.rlock.Lock()
defer c.rlock.Unlock()
c.canRead = canRead
}
// SetReadLimit sets the maximum number of bytes that can be read.
func (c *NetCtl) SetReadLimit(limit uint64) {
c.dlock.Lock()
defer c.dlock.Unlock()
for _, conn := range c.conns {
conn.Close()
}
c.rlock.Lock()
defer c.rlock.Unlock()
c.readLimit = limit
c.readCount = 0
}
// SetReadSpeed sets the maximum number of bytes that can be read per second.
func (c *NetCtl) SetReadSpeed(speed int) {
c.dlock.Lock()
defer c.dlock.Unlock()
for _, conn := range c.conns {
conn.Close()
}
c.rlock.Lock()
defer c.rlock.Unlock()
c.readSpeed = speed
}
// SetCanWrite sets whether the connection can write.
func (c *NetCtl) SetCanWrite(canWrite bool) {
c.dlock.Lock()
defer c.dlock.Unlock()
for _, conn := range c.conns {
conn.Close()
}
c.wlock.Lock()
defer c.wlock.Unlock()
c.canWrite = canWrite
}
// SetWriteLimit sets the maximum number of bytes that can be written.
func (c *NetCtl) SetWriteLimit(limit uint64) {
c.dlock.Lock()
defer c.dlock.Unlock()
for _, conn := range c.conns {
conn.Close()
}
c.wlock.Lock()
defer c.wlock.Unlock()
c.writeLimit = limit
c.writeCount = 0
}
// SetWriteSpeed sets the maximum number of bytes that can be written per second.
func (c *NetCtl) SetWriteSpeed(speed int) {
c.dlock.Lock()
defer c.dlock.Unlock()
for _, conn := range c.conns {
conn.Close()
}
c.wlock.Lock()
defer c.wlock.Unlock()
c.writeSpeed = speed
}
// OnDial adds a callback that is called with the created connection when a dial is successful.
func (c *NetCtl) OnDial(f func(net.Conn)) {
c.dlock.Lock()
defer c.dlock.Unlock()
c.onDial = append(c.onDial, f)
}
// OnRead adds a callback that is called with the read bytes when a read is successful.
func (c *NetCtl) OnRead(fn func([]byte)) {
c.rlock.Lock()
defer c.rlock.Unlock()
c.onRead = append(c.onRead, fn)
}
// OnWrite adds a callback that is called with the written bytes when a write is successful.
func (c *NetCtl) OnWrite(fn func([]byte)) {
c.wlock.Lock()
defer c.wlock.Unlock()
c.onWrite = append(c.onWrite, fn)
}
// Disable is equivalent to disallowing dial, read and write.
func (c *NetCtl) Disable() {
c.SetCanDial(false)
c.SetCanRead(false)
c.SetCanWrite(false)
}
// Enable is equivalent to allowing dial, read and write.
func (c *NetCtl) Enable() {
c.SetCanDial(true)
c.SetCanRead(true)
c.SetCanWrite(true)
}
// NewDialer returns a new dialer controlled by the ctl.
func (c *NetCtl) NewRoundTripper(tlsConfig *tls.Config) http.RoundTripper {
return &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return c.dial(ctx, &net.Dialer{}, network, addr)
},
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return c.dial(ctx, &tls.Dialer{Config: tlsConfig}, network, addr)
},
TLSClientConfig: tlsConfig,
}
}
// ctxDialer implements DialContext.
type ctxDialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
// dial dials using d, but only if the controller allows it.
func (c *NetCtl) dial(ctx context.Context, dialer ctxDialer, network, addr string) (net.Conn, error) {
c.dlock.Lock()
defer c.dlock.Unlock()
if !c.canDial {
return nil, errors.New("dial failed (not allowed)")
}
if c.dialLimit > 0 && c.dialCount >= c.dialLimit {
return nil, errors.New("dial failed (limit reached)")
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
c.dialCount++
for _, fn := range c.onDial {
fn(conn)
}
c.conns = append(c.conns, conn)
return newConn(conn, c), nil
}
// read reads from r, but only if the controller allows it.
func (c *NetCtl) read(r io.Reader, b []byte) (int, error) {
c.rlock.Lock()
defer c.rlock.Unlock()
if !c.canRead {
return 0, errors.New("read failed (not allowed)")
}
if c.readLimit > 0 && c.readCount >= c.readLimit {
return 0, errors.New("read failed (limit reached)")
}
var rem uint64
if c.readLimit > 0 && c.readLimit-c.readCount < uint64(len(b)) {
rem = c.readLimit - c.readCount
} else {
rem = uint64(len(b))
}
c.rlock.Unlock()
n, err := newSlowReader(r, c.readSpeed).Read(b[:rem])
c.rlock.Lock()
c.readCount += uint64(n)
for _, fn := range c.onRead {
fn(b[:n])
}
return n, err
}
// write writes to w, but only if the controller allows it.
func (c *NetCtl) write(w io.Writer, b []byte) (int, error) {
c.wlock.Lock()
defer c.wlock.Unlock()
if !c.canWrite {
return 0, errors.New("write failed (not allowed)")
}
if c.writeLimit > 0 && c.writeCount >= c.writeLimit {
return 0, errors.New("write failed (limit exceeded)")
}
var rem uint64
if c.writeLimit > 0 && c.writeLimit-c.writeCount < uint64(len(b)) {
rem = c.writeLimit - c.writeCount
} else {
rem = uint64(len(b))
}
c.wlock.Unlock()
n, err := newSlowWriter(w, c.writeSpeed).Write(b[:rem])
c.wlock.Lock()
c.writeCount += uint64(n)
for _, fn := range c.onWrite {
fn(b[:n])
}
if uint64(n) < rem {
return n, fmt.Errorf("write incomplete (limit reached)")
}
return n, err
}
// conn is a wrapper around net.conn that can be used to control whether a connection can read or write.
type conn struct {
net.Conn
ctl *NetCtl
}
func newConn(c net.Conn, ctl *NetCtl) *conn {
return &conn{
Conn: c,
ctl: ctl,
}
}
// Read reads from the wrapped connection, but only if the controller allows it.
func (c *conn) Read(b []byte) (int, error) {
return c.ctl.read(c.Conn, b)
}
// Write writes to the wrapped connection, but only if the controller allows it.
func (c *conn) Write(b []byte) (int, error) {
return c.ctl.write(c.Conn, b)
}
// slowReader is an io.Reader that reads at a fixed rate.
type slowReader struct {
r io.Reader
// bytesPerSec is the number of bytes to read per second.
bytesPerSec int
}
func newSlowReader(r io.Reader, bytesPerSec int) *slowReader {
return &slowReader{
r: r,
bytesPerSec: bytesPerSec,
}
}
func (r *slowReader) Read(b []byte) (int, error) {
start := time.Now()
n, err := r.r.Read(b)
if r.bytesPerSec > 0 {
time.Sleep(time.Until(start.Add(time.Duration(n*r.bytesPerSec) * time.Second)))
}
return n, err
}
// slowWriter is an io.Writer that writes at a fixed rate.
type slowWriter struct {
w io.Writer
// bytesPerSec is the number of bytes to write per second.
bytesPerSec int
}
func newSlowWriter(w io.Writer, bytesPerSec int) *slowWriter {
return &slowWriter{
w: w,
bytesPerSec: bytesPerSec,
}
}
func (w *slowWriter) Write(b []byte) (int, error) {
start := time.Now()
n, err := w.w.Write(b)
if w.bytesPerSec > 0 {
time.Sleep(time.Until(start.Add(time.Duration(n*w.bytesPerSec) * time.Second)))
}
return n, err
}

View File

@@ -14,7 +14,7 @@ import (
func TestNetCtl_ReadLimit(t *testing.T) {
// Create a test http server that writes 100 bytes.
// Including the header, this is 217 bytes (100 bytes + 117 bytes).
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
if _, err := w.Write(make([]byte, 100)); err != nil {
t.Fatal(err)
}
@@ -22,16 +22,16 @@ func TestNetCtl_ReadLimit(t *testing.T) {
defer ts.Close()
// Create a new net controller.
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
ctl.SetReadLimit(300)
// Create a new http client with the dialer.
client := &http.Client{
Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}),
}
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
netCtl.SetReadLimit(300)
// This should succeed.
if resp, err := client.Get(ts.URL); err != nil {
t.Fatal(err)
@@ -47,7 +47,7 @@ func TestNetCtl_ReadLimit(t *testing.T) {
func TestNetCtl_WriteLimit(t *testing.T) {
// Create a test http server that reads the given body.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if _, err := io.ReadAll(r.Body); err != nil {
t.Fatal(err)
}
@@ -55,16 +55,16 @@ func TestNetCtl_WriteLimit(t *testing.T) {
defer ts.Close()
// Create a new net controller.
netCtl := proton.NewNetCtl()
ctl := proton.NewNetCtl()
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
ctl.SetWriteLimit(300)
// Create a new http client with the dialer.
client := &http.Client{
Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}),
}
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
netCtl.SetWriteLimit(300)
// This should succeed.
if resp, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err != nil {
t.Fatal(err)

View File

@@ -62,37 +62,31 @@ func updateTime(_ *resty.Client, res *resty.Response) error {
return nil
}
// nolint:gosec
func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) {
if res.StatusCode() == http.StatusTooManyRequests {
if after := res.Header().Get("Retry-After"); after != "" {
l := logrus.WithFields(logrus.Fields{
"pkg": "go-proton-api",
"statusCode": res.StatusCode(),
"url": res.Request.URL,
"verb": res.Request.Method,
})
seconds, err := strconv.Atoi(after)
if err != nil {
l.WithField("after", after).WithError(err).Warning(
"Cannot convert Retry-After to a number, continue with default 10 second cooldown.",
)
seconds = 10
}
// To avoid spikes when all clients retry at the same time, we add some random wait.
seconds += rand.Intn(10) //nolint:gosec // It is OK to use weak random number generator here.
l = l.WithField("delay", seconds)
// Maximum retry time in client is is one minute. But
// here wait times can be longer e.g. high API load
l.Warn("Delay the retry after http response")
return time.Duration(seconds) * time.Second, nil
}
// 0 and no error means default behaviour which is exponential backoff with jitter.
if res.StatusCode() != http.StatusTooManyRequests {
return 0, nil
}
// 0 and no error means default behaviour which is exponential backoff with jitter.
return 0, nil
// Parse the Retry-After header, or fallback to 10 seconds.
after, err := strconv.Atoi(res.Header().Get("Retry-After"))
if err != nil {
after = 10
}
// Add some jitter to the delay.
after += rand.Intn(10)
logrus.WithFields(logrus.Fields{
"pkg": "go-proton-api",
"status": res.StatusCode(),
"url": res.Request.URL,
"method": res.Request.Method,
"after": after,
}).Warn("Too many requests, retrying after delay")
return time.Duration(after) * time.Second, nil
}
func catchTooManyRequests(res *resty.Response, _ error) bool {

View File

@@ -0,0 +1,25 @@
package backend
import "github.com/ProtonMail/gopenpgp/v2/crypto"
var preCompKey *crypto.Key
func init() {
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
if err != nil {
panic(err)
}
preCompKey = key
}
// FastGenerateKey is a fast version of GenerateKey that uses a pre-computed key.
// This is useful for testing but is incredibly insecure.
func FastGenerateKey(_, _ string, passphrase []byte, _ string, _ int) (string, error) {
encKey, err := preCompKey.Lock(passphrase)
if err != nil {
return "", err
}
return encKey.Armor()
}

View File

@@ -3,6 +3,7 @@ package server
import (
"net/http"
"net/url"
"time"
)
type Call struct {
@@ -10,6 +11,9 @@ type Call struct {
Method string
Status int
Time time.Time
Duration time.Duration
RequestHeader http.Header
RequestBody []byte

View File

@@ -1,22 +1,7 @@
package server
import (
"github.com/ProtonMail/go-proton-api/server/backend"
"github.com/ProtonMail/gopenpgp/v2/crypto"
)
import "github.com/ProtonMail/go-proton-api/server/backend"
func init() {
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
if err != nil {
panic(err)
}
backend.GenerateKey = func(_, _ string, passphrase []byte, _ string, _ int) (string, error) {
encKey, err := key.Lock(passphrase)
if err != nil {
return "", err
}
return encKey.Armor()
}
backend.GenerateKey = backend.FastGenerateKey
}

51
server/rate_limit.go Normal file
View File

@@ -0,0 +1,51 @@
package server
import (
"sync"
"time"
)
// rateLimiter is a rate limiter for the server.
// If more than limit requests are made in the time window, the server will return 429.
type rateLimiter struct {
// limit is the rate limit to apply to the server.
limit int
// window is the window in which to apply the rate limit.
window time.Duration
// nextReset is the time at which the rate limit will reset.
nextReset time.Time
// count is the number of calls made to the server.
count int
// countLock is a mutex for the callCount.
countLock sync.Mutex
}
func newRateLimiter(limit int, window time.Duration) *rateLimiter {
return &rateLimiter{
limit: limit,
window: window,
}
}
// exceeded checks the rate limit and returns how long to wait before the next request.
func (r *rateLimiter) exceeded() time.Duration {
r.countLock.Lock()
defer r.countLock.Unlock()
if time.Now().After(r.nextReset) {
r.count = 0
r.nextReset = time.Now().Add(r.window)
}
r.count++
if r.count > r.limit {
return time.Until(r.nextReset)
}
return 0
}

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@@ -20,6 +21,7 @@ func initRouter(s *Server) {
s.r.Use(
s.requireValidAppVersion(),
s.setSessionCookie(),
s.applyRateLimit(),
)
if core := s.r.Group("/core/v4"); core != nil {
@@ -171,8 +173,23 @@ func (s *Server) setSessionCookie() gin.HandlerFunc {
}
}
func (s *Server) applyRateLimit() gin.HandlerFunc {
return func(c *gin.Context) {
if s.rateLimit == nil {
return
}
if wait := s.rateLimit.exceeded(); wait > 0 {
c.Header("Retry-After", strconv.Itoa(int(wait.Seconds())))
c.AbortWithStatus(http.StatusTooManyRequests)
}
}
}
func (s *Server) logCalls() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
req, err := io.ReadAll(c.Request.Body)
if err != nil {
panic(err)
@@ -199,6 +216,9 @@ func (s *Server) logCalls() gin.HandlerFunc {
Method: c.Request.Method,
Status: c.Writer.Status(),
Time: start,
Duration: time.Since(start),
RequestHeader: c.Request.Header,
RequestBody: req,

View File

@@ -47,6 +47,9 @@ type Server struct {
// offline is whether to pretend the server is offline and return 5xx errors.
offline bool
// rateLimit is the rate limiter for the server.
rateLimit *rateLimiter
}
func New(opts ...Option) *Server {
@@ -69,7 +72,7 @@ func (s *Server) GetProxyURL() string {
return s.s.URL + "/proxy"
}
// GetDomain returns the domain of the server.
// GetDomain returns the domain of the server (e.g. "proton.local").
func (s *Server) GetDomain() string {
return s.domain
}

View File

@@ -12,11 +12,12 @@ import (
)
type serverBuilder struct {
withTLS bool
domain string
logger io.Writer
origin string
cacher AuthCacher
withTLS bool
domain string
logger io.Writer
origin string
cacher AuthCacher
rateLimiter *rateLimiter
}
func newServerBuilder() *serverBuilder {
@@ -46,6 +47,7 @@ func (builder *serverBuilder) build() *Server {
domain: builder.domain,
proxyOrigin: builder.origin,
authCacher: builder.cacher,
rateLimit: builder.rateLimiter,
}
if builder.withTLS {
@@ -143,3 +145,19 @@ type withAuthCache struct {
func (opt withAuthCache) config(builder *serverBuilder) {
builder.cacher = opt.cacher
}
func WithRateLimit(limit int, window time.Duration) Option {
return &withRateLimit{
limit: limit,
window: window,
}
}
type withRateLimit struct {
limit int
window time.Duration
}
func (opt withRateLimit) config(builder *serverBuilder) {
builder.rateLimiter = newRateLimiter(opt.limit, opt.window)
}

View File

@@ -77,7 +77,7 @@ func TestServer_Ping(t *testing.T) {
m := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
var status proton.Status
@@ -1313,7 +1313,7 @@ func TestServer_Messages_Fetch(t *testing.T) {
mm := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
defer mm.Close()
@@ -1355,7 +1355,7 @@ func TestServer_Messages_Status(t *testing.T) {
mm := proton.New(
proton.WithHostURL(s.GetHostURL()),
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
)
defer mm.Close()

View File

@@ -12,6 +12,8 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi
userKR, err := user.Keys.Unlock(saltedKeyPass, nil)
if err != nil {
return nil, nil, fmt.Errorf("failed to unlock user keys: %w", err)
} else if userKR.CountDecryptionEntities() == 0 {
return nil, nil, fmt.Errorf("failed to unlock any user keys")
}
addrKRs := make(map[string]*crypto.KeyRing)
@@ -19,9 +21,13 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi
for idx, addrKR := range parallel.Map(runtime.NumCPU(), addresses, func(addr Address) *crypto.KeyRing {
return addr.Keys.TryUnlock(saltedKeyPass, userKR)
}) {
if addrKR != nil {
addrKRs[addresses[idx].ID] = addrKR
if addrKR.CountDecryptionEntities() == 0 {
continue
} else if addrKR == nil {
continue
}
addrKRs[addresses[idx].ID] = addrKR
}
if len(addrKRs) == 0 {