mirror of
https://github.com/ProtonMail/go-proton-api.git
synced 2025-12-23 23:57:50 -05:00
refactor: NetCtl, transport, dialer, rate limiting
This commit is contained in:
28
auth_test.go
28
auth_test.go
@@ -16,7 +16,7 @@ func TestAuth(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
_, _, err := s.CreateUser("user", []byte("password"))
|
||||
_, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
m := proton.New(
|
||||
@@ -26,14 +26,14 @@ func TestAuth(t *testing.T) {
|
||||
defer m.Close()
|
||||
|
||||
// Create one session.
|
||||
c1, auth1, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
|
||||
c1, auth1, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Revoke all other sessions.
|
||||
require.NoError(t, c1.AuthRevokeAll(context.Background()))
|
||||
|
||||
// Create another session.
|
||||
c2, _, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
|
||||
c2, _, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// There should be two sessions.
|
||||
@@ -61,7 +61,7 @@ func TestAuth_Refresh(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Create a user on the server.
|
||||
userID, _, err := s.CreateUser("user", []byte("password"))
|
||||
userID, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// The auth is valid for 4 seconds.
|
||||
@@ -74,7 +74,7 @@ func TestAuth_Refresh(t *testing.T) {
|
||||
defer m.Close()
|
||||
|
||||
// Create one session for the user.
|
||||
c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
|
||||
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userID, auth.UserID)
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestAuth_Refresh(t *testing.T) {
|
||||
{
|
||||
user, err := c.GetUser(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "username", user.Name)
|
||||
require.Equal(t, "user", user.Name)
|
||||
require.Equal(t, userID, user.ID)
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ func TestAuth_Refresh(t *testing.T) {
|
||||
{
|
||||
user, err := c.GetUser(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "username", user.Name)
|
||||
require.Equal(t, "user", user.Name)
|
||||
require.Equal(t, userID, user.ID)
|
||||
}
|
||||
}
|
||||
@@ -106,7 +106,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Create a user on the server.
|
||||
userID, _, err := s.CreateUser("user", []byte("password"))
|
||||
userID, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// The auth is valid for 4 seconds.
|
||||
@@ -118,7 +118,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
|
||||
)
|
||||
defer m.Close()
|
||||
|
||||
c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
|
||||
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userID, auth.UserID)
|
||||
|
||||
@@ -128,7 +128,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
|
||||
parallel.Do(runtime.NumCPU(), 100, func(idx int) {
|
||||
user, err := c.GetUser(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "username", user.Name)
|
||||
require.Equal(t, "user", user.Name)
|
||||
require.Equal(t, userID, user.ID)
|
||||
})
|
||||
|
||||
@@ -139,7 +139,7 @@ func TestAuth_Refresh_Multi(t *testing.T) {
|
||||
parallel.Do(runtime.NumCPU(), 100, func(idx int) {
|
||||
user, err := c.GetUser(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "username", user.Name)
|
||||
require.Equal(t, "user", user.Name)
|
||||
require.Equal(t, userID, user.ID)
|
||||
})
|
||||
}
|
||||
@@ -149,7 +149,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
|
||||
defer s.Close()
|
||||
|
||||
// Create a user on the server.
|
||||
userID, _, err := s.CreateUser("user", []byte("password"))
|
||||
userID, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
m := proton.New(
|
||||
@@ -159,7 +159,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
|
||||
defer m.Close()
|
||||
|
||||
// Create one session for the user.
|
||||
c, auth, err := m.NewClientWithLogin(context.Background(), "username", []byte("password"))
|
||||
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userID, auth.UserID)
|
||||
|
||||
@@ -172,7 +172,7 @@ func TestAuth_Refresh_Deauth(t *testing.T) {
|
||||
{
|
||||
user, err := c.GetUser(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "username", user.Name)
|
||||
require.Equal(t, "user", user.Name)
|
||||
require.Equal(t, userID, user.ID)
|
||||
}
|
||||
|
||||
|
||||
311
dialer.go
311
dialer.go
@@ -1,311 +0,0 @@
|
||||
package proton
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func InsecureTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
}
|
||||
|
||||
// NetCtl can be used to control whether a dialer can dial, and whether the resulting
|
||||
// connection can read or write.
|
||||
type NetCtl struct {
|
||||
canDial atomicBool
|
||||
dialLimit atomicUint64
|
||||
|
||||
canRead atomicBool
|
||||
readLimit atomicUint64
|
||||
|
||||
canWrite atomicBool
|
||||
writeLimit atomicUint64
|
||||
|
||||
onDial []func(net.Conn)
|
||||
onRead []func([]byte)
|
||||
onWrite []func([]byte)
|
||||
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// NewNetCtl returns a new NetCtl with all fields set to true.
|
||||
func NewNetCtl() *NetCtl {
|
||||
return &NetCtl{
|
||||
canDial: atomicBool{b32(true)},
|
||||
canRead: atomicBool{b32(true)},
|
||||
canWrite: atomicBool{b32(true)},
|
||||
}
|
||||
}
|
||||
|
||||
// SetCanDial sets whether the dialer can dial.
|
||||
func (c *NetCtl) SetCanDial(canDial bool) {
|
||||
c.canDial.Store(canDial)
|
||||
}
|
||||
|
||||
// SetDialLimit sets the maximum number of times dialers using this controller can dial.
|
||||
func (c *NetCtl) SetDialLimit(limit uint64) {
|
||||
c.dialLimit.Store(limit)
|
||||
}
|
||||
|
||||
// SetCanRead sets whether the connection can read.
|
||||
func (c *NetCtl) SetCanRead(canRead bool) {
|
||||
c.canRead.Store(canRead)
|
||||
}
|
||||
|
||||
// SetReadLimit sets the maximum number of bytes that can be read.
|
||||
func (c *NetCtl) SetReadLimit(limit uint64) {
|
||||
c.readLimit.Store(limit)
|
||||
}
|
||||
|
||||
// SetCanWrite sets whether the connection can write.
|
||||
func (c *NetCtl) SetCanWrite(canWrite bool) {
|
||||
c.canWrite.Store(canWrite)
|
||||
}
|
||||
|
||||
// SetWriteLimit sets the maximum number of bytes that can be written.
|
||||
func (c *NetCtl) SetWriteLimit(limit uint64) {
|
||||
c.writeLimit.Store(limit)
|
||||
}
|
||||
|
||||
// OnDial adds a callback that is called with the created connection when a dial is successful.
|
||||
func (c *NetCtl) OnDial(f func(net.Conn)) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.onDial = append(c.onDial, f)
|
||||
}
|
||||
|
||||
// OnRead adds a callback that is called with the read bytes when a read is successful.
|
||||
func (c *NetCtl) OnRead(f func([]byte)) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.onRead = append(c.onRead, f)
|
||||
}
|
||||
|
||||
// OnWrite adds a callback that is called with the written bytes when a write is successful.
|
||||
func (c *NetCtl) OnWrite(f func([]byte)) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.onWrite = append(c.onWrite, f)
|
||||
}
|
||||
|
||||
// Disable is equivalent to disallowing dial, read and write.
|
||||
func (c *NetCtl) Disable() {
|
||||
c.SetCanDial(false)
|
||||
c.SetCanRead(false)
|
||||
c.SetCanWrite(false)
|
||||
}
|
||||
|
||||
// Enable is equivalent to allowing dial, read and write.
|
||||
func (c *NetCtl) Enable() {
|
||||
c.SetCanDial(true)
|
||||
c.SetCanRead(true)
|
||||
c.SetCanWrite(true)
|
||||
}
|
||||
|
||||
// Conn is a wrapper around net.Conn that can be used to control whether a connection can read or write.
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
|
||||
ctl *NetCtl
|
||||
|
||||
readLimiter *readLimiter
|
||||
writeLimiter *writeLimiter
|
||||
}
|
||||
|
||||
// Read reads from the wrapped connection, but only if the controller allows it.
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
if !c.ctl.canRead.Load() {
|
||||
return 0, errors.New("cannot read")
|
||||
}
|
||||
|
||||
n, err := c.readLimiter.read(c.Conn, b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
for _, f := range c.ctl.onRead {
|
||||
f(b[:n])
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write writes to the wrapped connection, but only if the controller allows it.
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
if !c.ctl.canWrite.Load() {
|
||||
return 0, errors.New("cannot write")
|
||||
}
|
||||
|
||||
n, err := c.writeLimiter.write(c.Conn, b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
for _, f := range c.ctl.onWrite {
|
||||
f(b[:n])
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Dialer performs network dialing, but only if the controller allows it.
|
||||
type Dialer struct {
|
||||
ctl *NetCtl
|
||||
|
||||
netDialer *net.Dialer
|
||||
tlsDialer *tls.Dialer
|
||||
tlsConfig *tls.Config
|
||||
|
||||
readLimiter *readLimiter
|
||||
writeLimiter *writeLimiter
|
||||
|
||||
dialCount atomicUint64
|
||||
}
|
||||
|
||||
// NewDialer returns a new dialer using the given net controller.
|
||||
// It optionally uses a provided tls config.
|
||||
func NewDialer(ctl *NetCtl, tlsConfig *tls.Config) *Dialer {
|
||||
return &Dialer{
|
||||
ctl: ctl,
|
||||
|
||||
netDialer: &net.Dialer{},
|
||||
tlsDialer: &tls.Dialer{Config: tlsConfig},
|
||||
tlsConfig: tlsConfig,
|
||||
|
||||
readLimiter: newReadLimiter(ctl),
|
||||
writeLimiter: newWriteLimiter(ctl),
|
||||
|
||||
dialCount: atomicUint64{0},
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext dials a network connection, but only if the controller allows it.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return d.dialWithDialer(ctx, network, addr, d.netDialer)
|
||||
}
|
||||
|
||||
// DialTLSContext dials a TLS network connection, but only if the controller allows it.
|
||||
func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return d.dialWithDialer(ctx, network, addr, d.tlsDialer)
|
||||
}
|
||||
|
||||
// dialWithDialer dials a network connection using the given dialer, but only if the controller allows it.
|
||||
func (d *Dialer) dialWithDialer(ctx context.Context, network, addr string, dialer dialer) (net.Conn, error) {
|
||||
if !d.ctl.canDial.Load() {
|
||||
return nil, errors.New("cannot dial")
|
||||
}
|
||||
|
||||
if limit := d.ctl.dialLimit.Load(); limit > 0 && d.dialCount.Load() >= limit {
|
||||
return nil, errors.New("dial limit reached")
|
||||
} else {
|
||||
d.dialCount.Add(1)
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.ctl.lock.Lock()
|
||||
defer d.ctl.lock.Unlock()
|
||||
|
||||
for _, f := range d.ctl.onDial {
|
||||
f(conn)
|
||||
}
|
||||
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
ctl: d.ctl,
|
||||
|
||||
readLimiter: d.readLimiter,
|
||||
writeLimiter: d.writeLimiter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetRoundTripper returns a new http.RoundTripper that uses the dialer.
|
||||
func (d *Dialer) GetRoundTripper() http.RoundTripper {
|
||||
return &http.Transport{
|
||||
DialContext: d.DialContext,
|
||||
DialTLSContext: d.DialTLSContext,
|
||||
TLSClientConfig: d.tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
type dialer interface {
|
||||
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type readLimiter struct {
|
||||
ctl *NetCtl
|
||||
|
||||
count atomicUint64
|
||||
}
|
||||
|
||||
// newReadLimiter returns a new io.Reader that reads from r, but only up to limit bytes.
|
||||
func newReadLimiter(ctl *NetCtl) *readLimiter {
|
||||
return &readLimiter{
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
func (limiter *readLimiter) read(r io.Reader, b []byte) (int, error) {
|
||||
if limit := limiter.ctl.readLimit.Load(); limit > 0 && limiter.count.Load() >= limit {
|
||||
return 0, fmt.Errorf("refusing to read: read limit reached")
|
||||
}
|
||||
|
||||
n, err := r.Read(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
if limit := limiter.ctl.readLimit.Load(); limit > 0 {
|
||||
if new := limiter.count.Add(uint64(n)); new >= limit {
|
||||
return 0, fmt.Errorf("read failed: read limit reached")
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
type writeLimiter struct {
|
||||
ctl *NetCtl
|
||||
|
||||
count atomicUint64
|
||||
}
|
||||
|
||||
// newWriteLimiter returns a new io.Writer that writes to w, but only up to limit bytes.
|
||||
func newWriteLimiter(ctl *NetCtl) *writeLimiter {
|
||||
return &writeLimiter{
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
func (limiter *writeLimiter) write(w io.Writer, b []byte) (int, error) {
|
||||
if limit := limiter.ctl.writeLimit.Load(); limit > 0 && limiter.count.Load() >= limit {
|
||||
return 0, fmt.Errorf("refusing to write: write limit reached")
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
if limit := limiter.ctl.writeLimit.Load(); limit > 0 {
|
||||
if new := limiter.count.Add(uint64(n)); new >= limit {
|
||||
return 0, fmt.Errorf("write failed: write limit reached")
|
||||
}
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
@@ -22,13 +22,13 @@ func TestEventStreamer(t *testing.T) {
|
||||
proton.WithTransport(proton.InsecureTransport()),
|
||||
)
|
||||
|
||||
_, _, err := s.CreateUser("user", []byte("password"))
|
||||
_, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
c, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))
|
||||
c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
createTestMessages(t, c, "password", 10)
|
||||
createTestMessages(t, c, "pass", 10)
|
||||
|
||||
latestEventID, err := c.GetLatestEventID(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -53,7 +53,7 @@ func TestEventStreamer(t *testing.T) {
|
||||
c.Close()
|
||||
|
||||
// Create a new client and perform some actions with it to generate more events.
|
||||
cc, _, err := m.NewClientWithLogin(ctx, "username", []byte("password"))
|
||||
cc, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
defer cc.Close()
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ func TestStatus(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,7 +39,7 @@ func TestStatus(t *testing.T) {
|
||||
require.Zero(t, called)
|
||||
|
||||
// Now we simulate a network failure.
|
||||
netCtl.Disable()
|
||||
ctl.Disable()
|
||||
|
||||
// This should fail.
|
||||
require.Error(t, m.Ping(context.Background()))
|
||||
@@ -49,7 +49,7 @@ func TestStatus(t *testing.T) {
|
||||
require.Equal(t, proton.StatusDown, status)
|
||||
|
||||
// Now we simulate a network restoration.
|
||||
netCtl.Enable()
|
||||
ctl.Enable()
|
||||
|
||||
// This should succeed.
|
||||
require.NoError(t, m.Ping(context.Background()))
|
||||
@@ -63,11 +63,11 @@ func TestStatus_NoDial(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -81,7 +81,7 @@ func TestStatus_NoDial(t *testing.T) {
|
||||
})
|
||||
|
||||
// Disable dialing.
|
||||
netCtl.SetCanDial(false)
|
||||
ctl.SetCanDial(false)
|
||||
|
||||
// This should fail.
|
||||
require.Error(t, m.Ping(context.Background()))
|
||||
@@ -95,11 +95,11 @@ func TestStatus_NoRead(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -113,7 +113,7 @@ func TestStatus_NoRead(t *testing.T) {
|
||||
})
|
||||
|
||||
// Disable reading.
|
||||
netCtl.SetCanRead(false)
|
||||
ctl.SetCanRead(false)
|
||||
|
||||
// This should fail.
|
||||
require.Error(t, m.Ping(context.Background()))
|
||||
@@ -127,11 +127,11 @@ func TestStatus_NoWrite(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -145,7 +145,7 @@ func TestStatus_NoWrite(t *testing.T) {
|
||||
})
|
||||
|
||||
// Disable writing.
|
||||
netCtl.SetCanWrite(false)
|
||||
ctl.SetCanWrite(false)
|
||||
|
||||
// This should fail.
|
||||
require.Error(t, m.Ping(context.Background()))
|
||||
@@ -162,17 +162,17 @@ func TestStatus_NoReadExistingConn(t *testing.T) {
|
||||
_, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
var dialed int
|
||||
|
||||
netCtl.OnDial(func(net.Conn) {
|
||||
ctl.OnDial(func(net.Conn) {
|
||||
dialed++
|
||||
})
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
// This should succeed.
|
||||
@@ -184,13 +184,10 @@ func TestStatus_NoReadExistingConn(t *testing.T) {
|
||||
require.Equal(t, 1, dialed)
|
||||
|
||||
// Disable reading on the existing connection.
|
||||
netCtl.SetCanRead(false)
|
||||
ctl.SetCanRead(false)
|
||||
|
||||
// This should fail because we won't be able to read the response.
|
||||
require.Error(t, getErr(c.GetUser(context.Background())))
|
||||
|
||||
// We should still have dialed once; the connection should have been reused.
|
||||
require.Equal(t, 1, dialed)
|
||||
}
|
||||
|
||||
func TestStatus_NoWriteExistingConn(t *testing.T) {
|
||||
@@ -200,17 +197,17 @@ func TestStatus_NoWriteExistingConn(t *testing.T) {
|
||||
_, _, err := s.CreateUser("user", []byte("pass"))
|
||||
require.NoError(t, err)
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
var dialed int
|
||||
|
||||
netCtl.OnDial(func(net.Conn) {
|
||||
ctl.OnDial(func(net.Conn) {
|
||||
dialed++
|
||||
})
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
proton.WithRetryCount(0),
|
||||
)
|
||||
|
||||
@@ -223,7 +220,7 @@ func TestStatus_NoWriteExistingConn(t *testing.T) {
|
||||
require.Equal(t, 1, dialed)
|
||||
|
||||
// Disable reading on the existing connection.
|
||||
netCtl.SetCanWrite(false)
|
||||
ctl.SetCanWrite(false)
|
||||
|
||||
// This should fail because we won't be able to write the request.
|
||||
require.Error(t, c.LabelMessages(context.Background(), []string{"messageID"}, proton.TrashLabel))
|
||||
@@ -240,7 +237,7 @@ func TestStatus_ContextCancel(t *testing.T) {
|
||||
|
||||
var called int
|
||||
|
||||
m.AddStatusObserver(func(val proton.Status) {
|
||||
m.AddStatusObserver(func(proton.Status) {
|
||||
called++
|
||||
})
|
||||
|
||||
@@ -263,7 +260,7 @@ func TestStatus_ContextTimeout(t *testing.T) {
|
||||
|
||||
var called int
|
||||
|
||||
m.AddStatusObserver(func(val proton.Status) {
|
||||
m.AddStatusObserver(func(proton.Status) {
|
||||
called++
|
||||
})
|
||||
|
||||
|
||||
120
manager_test.go
120
manager_test.go
@@ -4,16 +4,15 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/ProtonMail/go-proton-api/server"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -21,17 +20,17 @@ func TestConnectionReuse(t *testing.T) {
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
var dialed int
|
||||
|
||||
netCtl.OnDial(func(net.Conn) {
|
||||
ctl.OnDial(func(net.Conn) {
|
||||
dialed++
|
||||
})
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
// This should succeed; the resulting connection should be reused.
|
||||
@@ -69,89 +68,70 @@ func TestAuthRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleTooManyRequests(t *testing.T) {
|
||||
var numCalls int
|
||||
// Create a server with a rate limit of 1 request per second.
|
||||
s := server.New(server.WithRateLimit(1, time.Second))
|
||||
defer s.Close()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
numCalls++
|
||||
var calls []server.Call
|
||||
|
||||
if numCalls < 5 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
// Watch the calls made.
|
||||
s.AddCallWatcher(func(call server.Call) {
|
||||
calls = append(calls, call)
|
||||
})
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(ts.URL),
|
||||
proton.WithRetryCount(5),
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.InsecureTransport()),
|
||||
)
|
||||
defer m.Close()
|
||||
|
||||
// The call should succeed because the 5th retry should succeed (429s are retried).
|
||||
c := m.NewClient("", "", "")
|
||||
defer c.Close()
|
||||
|
||||
if _, err := c.GetAddresses(context.Background()); err != nil {
|
||||
t.Fatal("got unexpected error", err)
|
||||
// Make five calls; they should all succeed, but will be rate limited.
|
||||
for i := 0; i < 5; i++ {
|
||||
require.NoError(t, m.Ping(context.Background()))
|
||||
}
|
||||
|
||||
// The server should be called 5 times.
|
||||
// The first four calls should return 429 and the last call should return 200.
|
||||
if numCalls != 5 {
|
||||
t.Fatal("expected numCalls to be 5, instead got", numCalls)
|
||||
// After each 429 response, we should wait at least the requested duration before making the next request.
|
||||
for idx, call := range calls {
|
||||
if call.Status == http.StatusTooManyRequests {
|
||||
after, err := strconv.Atoi(call.ResponseHeader.Get("Retry-After"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// The next call should be made after the requested duration.
|
||||
require.True(t, calls[idx+1].Time.After(call.Time.Add(time.Duration(after)*time.Second)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTooManyRequestsRetryAfter(t *testing.T) {
|
||||
getDelay := func(iCal int) time.Duration {
|
||||
return time.Duration(5*1<<iCal) * time.Second
|
||||
}
|
||||
func TestHandleTooManyRequests_Malformed(t *testing.T) {
|
||||
var calls []time.Time
|
||||
|
||||
iRetry := -1
|
||||
lastCall := time.Now()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
currentCall := time.Now()
|
||||
|
||||
if iRetry >= 0 {
|
||||
delay := getDelay(iRetry)
|
||||
assert.False(t, currentCall.Before(
|
||||
lastCall.Add(delay)),
|
||||
"Delay was %v but expected to have %v",
|
||||
currentCall.Sub(lastCall),
|
||||
delay,
|
||||
)
|
||||
}
|
||||
|
||||
iRetry++
|
||||
lastCall = currentCall
|
||||
|
||||
// test defaul 10sec
|
||||
if iRetry == 1 {
|
||||
w.Header().Set("Retry-After", "something")
|
||||
} else {
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", getDelay(iRetry).Seconds()))
|
||||
}
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
if len(calls) == 0 {
|
||||
w.Header().Set("Retry-After", "malformed")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
calls = append(calls, time.Now())
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(ts.URL),
|
||||
proton.WithRetryCount(3),
|
||||
)
|
||||
m := proton.New(proton.WithHostURL(ts.URL))
|
||||
defer m.Close()
|
||||
|
||||
c := m.NewClient("", "", "")
|
||||
defer c.Close()
|
||||
require.NoError(t, m.Ping(context.Background()))
|
||||
|
||||
_, err := c.GetAddresses(context.Background())
|
||||
require.Error(t, err)
|
||||
// The first call should fail because the Retry-After header is invalid.
|
||||
// The second call should succeed.
|
||||
require.Len(t, calls, 2)
|
||||
|
||||
// The second call should be made at least 10 seconds after the first call.
|
||||
require.True(t, calls[1].After(calls[0].Add(10*time.Second)))
|
||||
}
|
||||
|
||||
func TestHandleUnprocessableEntity(t *testing.T) {
|
||||
var numCalls int
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
numCalls++
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
}))
|
||||
@@ -180,7 +160,7 @@ func TestHandleUnprocessableEntity(t *testing.T) {
|
||||
func TestHandleDialFailure(t *testing.T) {
|
||||
var numCalls int
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
numCalls++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -210,7 +190,7 @@ func TestHandleDialFailure(t *testing.T) {
|
||||
func TestHandleTooManyDialFailures(t *testing.T) {
|
||||
var numCalls int
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
numCalls++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
@@ -242,7 +222,7 @@ func TestHandleTooManyDialFailures(t *testing.T) {
|
||||
func TestRetriesWithContextTimeout(t *testing.T) {
|
||||
var numCalls int
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
numCalls++
|
||||
|
||||
if numCalls < 5 {
|
||||
@@ -276,7 +256,7 @@ func TestRetriesWithContextTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestReturnErrNoConnection(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer ts.Close()
|
||||
@@ -305,7 +285,7 @@ func TestStatusCallbacks(t *testing.T) {
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
statusCh := make(chan proton.Status, 1)
|
||||
|
||||
398
netctl.go
Normal file
398
netctl.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package proton
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// InsecureTransport returns an http.Transport with InsecureSkipVerify set to true.
|
||||
func InsecureTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
}
|
||||
|
||||
// ctl can be used to control whether a dialer can dial, and whether the resulting
|
||||
// connection can read or write.
|
||||
type NetCtl struct {
|
||||
canDial bool
|
||||
dialLimit uint64
|
||||
dialCount uint64
|
||||
onDial []func(net.Conn)
|
||||
dlock sync.RWMutex
|
||||
|
||||
canRead bool
|
||||
readLimit uint64
|
||||
readCount uint64
|
||||
readSpeed int
|
||||
onRead []func([]byte)
|
||||
rlock sync.RWMutex
|
||||
|
||||
canWrite bool
|
||||
writeLimit uint64
|
||||
writeCount uint64
|
||||
writeSpeed int
|
||||
onWrite []func([]byte)
|
||||
wlock sync.RWMutex
|
||||
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
// NewNetCtl returns a new ctl with all fields set to true.
|
||||
func NewNetCtl() *NetCtl {
|
||||
return &NetCtl{
|
||||
canDial: true,
|
||||
canRead: true,
|
||||
canWrite: true,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCanDial sets whether the dialer can dial.
|
||||
func (c *NetCtl) SetCanDial(canDial bool) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
c.canDial = canDial
|
||||
}
|
||||
|
||||
// SetDialLimit sets the maximum number of times dialers using this controller can dial.
|
||||
func (c *NetCtl) SetDialLimit(limit uint64) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
c.dialLimit = limit
|
||||
}
|
||||
|
||||
// SetCanRead sets whether the connection can read.
|
||||
func (c *NetCtl) SetCanRead(canRead bool) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
for _, conn := range c.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
c.rlock.Lock()
|
||||
defer c.rlock.Unlock()
|
||||
|
||||
c.canRead = canRead
|
||||
}
|
||||
|
||||
// SetReadLimit sets the maximum number of bytes that can be read.
|
||||
func (c *NetCtl) SetReadLimit(limit uint64) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
for _, conn := range c.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
c.rlock.Lock()
|
||||
defer c.rlock.Unlock()
|
||||
|
||||
c.readLimit = limit
|
||||
c.readCount = 0
|
||||
}
|
||||
|
||||
// SetReadSpeed sets the maximum number of bytes that can be read per second.
|
||||
func (c *NetCtl) SetReadSpeed(speed int) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
for _, conn := range c.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
c.rlock.Lock()
|
||||
defer c.rlock.Unlock()
|
||||
|
||||
c.readSpeed = speed
|
||||
}
|
||||
|
||||
// SetCanWrite sets whether the connection can write.
|
||||
func (c *NetCtl) SetCanWrite(canWrite bool) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
for _, conn := range c.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
c.wlock.Lock()
|
||||
defer c.wlock.Unlock()
|
||||
|
||||
c.canWrite = canWrite
|
||||
}
|
||||
|
||||
// SetWriteLimit sets the maximum number of bytes that can be written.
|
||||
func (c *NetCtl) SetWriteLimit(limit uint64) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
for _, conn := range c.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
c.wlock.Lock()
|
||||
defer c.wlock.Unlock()
|
||||
|
||||
c.writeLimit = limit
|
||||
c.writeCount = 0
|
||||
}
|
||||
|
||||
// SetWriteSpeed sets the maximum number of bytes that can be written per second.
|
||||
func (c *NetCtl) SetWriteSpeed(speed int) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
for _, conn := range c.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
c.wlock.Lock()
|
||||
defer c.wlock.Unlock()
|
||||
|
||||
c.writeSpeed = speed
|
||||
}
|
||||
|
||||
// OnDial adds a callback that is called with the created connection when a dial is successful.
|
||||
func (c *NetCtl) OnDial(f func(net.Conn)) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
c.onDial = append(c.onDial, f)
|
||||
}
|
||||
|
||||
// OnRead adds a callback that is called with the read bytes when a read is successful.
|
||||
func (c *NetCtl) OnRead(fn func([]byte)) {
|
||||
c.rlock.Lock()
|
||||
defer c.rlock.Unlock()
|
||||
|
||||
c.onRead = append(c.onRead, fn)
|
||||
}
|
||||
|
||||
// OnWrite adds a callback that is called with the written bytes when a write is successful.
|
||||
func (c *NetCtl) OnWrite(fn func([]byte)) {
|
||||
c.wlock.Lock()
|
||||
defer c.wlock.Unlock()
|
||||
|
||||
c.onWrite = append(c.onWrite, fn)
|
||||
}
|
||||
|
||||
// Disable is equivalent to disallowing dial, read and write.
|
||||
func (c *NetCtl) Disable() {
|
||||
c.SetCanDial(false)
|
||||
c.SetCanRead(false)
|
||||
c.SetCanWrite(false)
|
||||
}
|
||||
|
||||
// Enable is equivalent to allowing dial, read and write.
|
||||
func (c *NetCtl) Enable() {
|
||||
c.SetCanDial(true)
|
||||
c.SetCanRead(true)
|
||||
c.SetCanWrite(true)
|
||||
}
|
||||
|
||||
// NewDialer returns a new dialer controlled by the ctl.
|
||||
func (c *NetCtl) NewRoundTripper(tlsConfig *tls.Config) http.RoundTripper {
|
||||
return &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return c.dial(ctx, &net.Dialer{}, network, addr)
|
||||
},
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return c.dial(ctx, &tls.Dialer{Config: tlsConfig}, network, addr)
|
||||
},
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// ctxDialer implements DialContext.
|
||||
type ctxDialer interface {
|
||||
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// dial dials using d, but only if the controller allows it.
|
||||
func (c *NetCtl) dial(ctx context.Context, dialer ctxDialer, network, addr string) (net.Conn, error) {
|
||||
c.dlock.Lock()
|
||||
defer c.dlock.Unlock()
|
||||
|
||||
if !c.canDial {
|
||||
return nil, errors.New("dial failed (not allowed)")
|
||||
}
|
||||
|
||||
if c.dialLimit > 0 && c.dialCount >= c.dialLimit {
|
||||
return nil, errors.New("dial failed (limit reached)")
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.dialCount++
|
||||
|
||||
for _, fn := range c.onDial {
|
||||
fn(conn)
|
||||
}
|
||||
|
||||
c.conns = append(c.conns, conn)
|
||||
|
||||
return newConn(conn, c), nil
|
||||
}
|
||||
|
||||
// read reads from r, but only if the controller allows it.
|
||||
func (c *NetCtl) read(r io.Reader, b []byte) (int, error) {
|
||||
c.rlock.Lock()
|
||||
defer c.rlock.Unlock()
|
||||
|
||||
if !c.canRead {
|
||||
return 0, errors.New("read failed (not allowed)")
|
||||
}
|
||||
|
||||
if c.readLimit > 0 && c.readCount >= c.readLimit {
|
||||
return 0, errors.New("read failed (limit reached)")
|
||||
}
|
||||
|
||||
var rem uint64
|
||||
|
||||
if c.readLimit > 0 && c.readLimit-c.readCount < uint64(len(b)) {
|
||||
rem = c.readLimit - c.readCount
|
||||
} else {
|
||||
rem = uint64(len(b))
|
||||
}
|
||||
|
||||
c.rlock.Unlock()
|
||||
n, err := newSlowReader(r, c.readSpeed).Read(b[:rem])
|
||||
c.rlock.Lock()
|
||||
|
||||
c.readCount += uint64(n)
|
||||
|
||||
for _, fn := range c.onRead {
|
||||
fn(b[:n])
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// write writes to w, but only if the controller allows it.
|
||||
func (c *NetCtl) write(w io.Writer, b []byte) (int, error) {
|
||||
c.wlock.Lock()
|
||||
defer c.wlock.Unlock()
|
||||
|
||||
if !c.canWrite {
|
||||
return 0, errors.New("write failed (not allowed)")
|
||||
}
|
||||
|
||||
if c.writeLimit > 0 && c.writeCount >= c.writeLimit {
|
||||
return 0, errors.New("write failed (limit exceeded)")
|
||||
}
|
||||
|
||||
var rem uint64
|
||||
|
||||
if c.writeLimit > 0 && c.writeLimit-c.writeCount < uint64(len(b)) {
|
||||
rem = c.writeLimit - c.writeCount
|
||||
} else {
|
||||
rem = uint64(len(b))
|
||||
}
|
||||
|
||||
c.wlock.Unlock()
|
||||
n, err := newSlowWriter(w, c.writeSpeed).Write(b[:rem])
|
||||
c.wlock.Lock()
|
||||
|
||||
c.writeCount += uint64(n)
|
||||
|
||||
for _, fn := range c.onWrite {
|
||||
fn(b[:n])
|
||||
}
|
||||
|
||||
if uint64(n) < rem {
|
||||
return n, fmt.Errorf("write incomplete (limit reached)")
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// conn is a wrapper around net.conn that can be used to control whether a connection can read or write.
|
||||
type conn struct {
|
||||
net.Conn
|
||||
|
||||
ctl *NetCtl
|
||||
}
|
||||
|
||||
func newConn(c net.Conn, ctl *NetCtl) *conn {
|
||||
return &conn{
|
||||
Conn: c,
|
||||
ctl: ctl,
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from the wrapped connection, but only if the controller allows it.
|
||||
func (c *conn) Read(b []byte) (int, error) {
|
||||
return c.ctl.read(c.Conn, b)
|
||||
}
|
||||
|
||||
// Write writes to the wrapped connection, but only if the controller allows it.
|
||||
func (c *conn) Write(b []byte) (int, error) {
|
||||
return c.ctl.write(c.Conn, b)
|
||||
}
|
||||
|
||||
// slowReader is an io.Reader that reads at a fixed rate.
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
|
||||
// bytesPerSec is the number of bytes to read per second.
|
||||
bytesPerSec int
|
||||
}
|
||||
|
||||
func newSlowReader(r io.Reader, bytesPerSec int) *slowReader {
|
||||
return &slowReader{
|
||||
r: r,
|
||||
bytesPerSec: bytesPerSec,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *slowReader) Read(b []byte) (int, error) {
|
||||
start := time.Now()
|
||||
|
||||
n, err := r.r.Read(b)
|
||||
|
||||
if r.bytesPerSec > 0 {
|
||||
time.Sleep(time.Until(start.Add(time.Duration(n*r.bytesPerSec) * time.Second)))
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// slowWriter is an io.Writer that writes at a fixed rate.
|
||||
type slowWriter struct {
|
||||
w io.Writer
|
||||
|
||||
// bytesPerSec is the number of bytes to write per second.
|
||||
bytesPerSec int
|
||||
}
|
||||
|
||||
func newSlowWriter(w io.Writer, bytesPerSec int) *slowWriter {
|
||||
return &slowWriter{
|
||||
w: w,
|
||||
bytesPerSec: bytesPerSec,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *slowWriter) Write(b []byte) (int, error) {
|
||||
start := time.Now()
|
||||
|
||||
n, err := w.w.Write(b)
|
||||
|
||||
if w.bytesPerSec > 0 {
|
||||
time.Sleep(time.Until(start.Add(time.Duration(n*w.bytesPerSec) * time.Second)))
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
func TestNetCtl_ReadLimit(t *testing.T) {
|
||||
// Create a test http server that writes 100 bytes.
|
||||
// Including the header, this is 217 bytes (100 bytes + 117 bytes).
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
if _, err := w.Write(make([]byte, 100)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -22,16 +22,16 @@ func TestNetCtl_ReadLimit(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
// Create a new net controller.
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
|
||||
ctl.SetReadLimit(300)
|
||||
|
||||
// Create a new http client with the dialer.
|
||||
client := &http.Client{
|
||||
Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||
Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}),
|
||||
}
|
||||
|
||||
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
|
||||
netCtl.SetReadLimit(300)
|
||||
|
||||
// This should succeed.
|
||||
if resp, err := client.Get(ts.URL); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -47,7 +47,7 @@ func TestNetCtl_ReadLimit(t *testing.T) {
|
||||
|
||||
func TestNetCtl_WriteLimit(t *testing.T) {
|
||||
// Create a test http server that reads the given body.
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
if _, err := io.ReadAll(r.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -55,16 +55,16 @@ func TestNetCtl_WriteLimit(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
// Create a new net controller.
|
||||
netCtl := proton.NewNetCtl()
|
||||
ctl := proton.NewNetCtl()
|
||||
|
||||
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
|
||||
ctl.SetWriteLimit(300)
|
||||
|
||||
// Create a new http client with the dialer.
|
||||
client := &http.Client{
|
||||
Transport: proton.NewDialer(netCtl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper(),
|
||||
Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}),
|
||||
}
|
||||
|
||||
// Set the read limit to 300 bytes -- the first request should succeed, the second should fail.
|
||||
netCtl.SetWriteLimit(300)
|
||||
|
||||
// This should succeed.
|
||||
if resp, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err != nil {
|
||||
t.Fatal(err)
|
||||
50
response.go
50
response.go
@@ -62,37 +62,31 @@ func updateTime(_ *resty.Client, res *resty.Response) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec
|
||||
func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) {
|
||||
if res.StatusCode() == http.StatusTooManyRequests {
|
||||
if after := res.Header().Get("Retry-After"); after != "" {
|
||||
l := logrus.WithFields(logrus.Fields{
|
||||
"pkg": "go-proton-api",
|
||||
"statusCode": res.StatusCode(),
|
||||
"url": res.Request.URL,
|
||||
"verb": res.Request.Method,
|
||||
})
|
||||
|
||||
seconds, err := strconv.Atoi(after)
|
||||
if err != nil {
|
||||
l.WithField("after", after).WithError(err).Warning(
|
||||
"Cannot convert Retry-After to a number, continue with default 10 second cooldown.",
|
||||
)
|
||||
seconds = 10
|
||||
}
|
||||
|
||||
// To avoid spikes when all clients retry at the same time, we add some random wait.
|
||||
seconds += rand.Intn(10) //nolint:gosec // It is OK to use weak random number generator here.
|
||||
l = l.WithField("delay", seconds)
|
||||
|
||||
// Maximum retry time in client is is one minute. But
|
||||
// here wait times can be longer e.g. high API load
|
||||
l.Warn("Delay the retry after http response")
|
||||
return time.Duration(seconds) * time.Second, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 0 and no error means default behaviour which is exponential backoff with jitter.
|
||||
if res.StatusCode() != http.StatusTooManyRequests {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Parse the Retry-After header, or fallback to 10 seconds.
|
||||
after, err := strconv.Atoi(res.Header().Get("Retry-After"))
|
||||
if err != nil {
|
||||
after = 10
|
||||
}
|
||||
|
||||
// Add some jitter to the delay.
|
||||
after += rand.Intn(10)
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"pkg": "go-proton-api",
|
||||
"status": res.StatusCode(),
|
||||
"url": res.Request.URL,
|
||||
"method": res.Request.Method,
|
||||
"after": after,
|
||||
}).Warn("Too many requests, retrying after delay")
|
||||
|
||||
return time.Duration(after) * time.Second, nil
|
||||
}
|
||||
|
||||
func catchTooManyRequests(res *resty.Response, _ error) bool {
|
||||
|
||||
25
server/backend/crypto_fast.go
Normal file
25
server/backend/crypto_fast.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package backend
|
||||
|
||||
import "github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
|
||||
var preCompKey *crypto.Key
|
||||
|
||||
func init() {
|
||||
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
preCompKey = key
|
||||
}
|
||||
|
||||
// FastGenerateKey is a fast version of GenerateKey that uses a pre-computed key.
|
||||
// This is useful for testing but is incredibly insecure.
|
||||
func FastGenerateKey(_, _ string, passphrase []byte, _ string, _ int) (string, error) {
|
||||
encKey, err := preCompKey.Lock(passphrase)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return encKey.Armor()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Call struct {
|
||||
@@ -10,6 +11,9 @@ type Call struct {
|
||||
Method string
|
||||
Status int
|
||||
|
||||
Time time.Time
|
||||
Duration time.Duration
|
||||
|
||||
RequestHeader http.Header
|
||||
RequestBody []byte
|
||||
|
||||
|
||||
@@ -1,22 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/go-proton-api/server/backend"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
)
|
||||
import "github.com/ProtonMail/go-proton-api/server/backend"
|
||||
|
||||
func init() {
|
||||
key, err := crypto.GenerateKey("name", "email", "rsa", 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backend.GenerateKey = func(_, _ string, passphrase []byte, _ string, _ int) (string, error) {
|
||||
encKey, err := key.Lock(passphrase)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return encKey.Armor()
|
||||
}
|
||||
backend.GenerateKey = backend.FastGenerateKey
|
||||
}
|
||||
|
||||
51
server/rate_limit.go
Normal file
51
server/rate_limit.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// rateLimiter is a rate limiter for the server.
|
||||
// If more than limit requests are made in the time window, the server will return 429.
|
||||
type rateLimiter struct {
|
||||
// limit is the rate limit to apply to the server.
|
||||
limit int
|
||||
|
||||
// window is the window in which to apply the rate limit.
|
||||
window time.Duration
|
||||
|
||||
// nextReset is the time at which the rate limit will reset.
|
||||
nextReset time.Time
|
||||
|
||||
// count is the number of calls made to the server.
|
||||
count int
|
||||
|
||||
// countLock is a mutex for the callCount.
|
||||
countLock sync.Mutex
|
||||
}
|
||||
|
||||
func newRateLimiter(limit int, window time.Duration) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// exceeded checks the rate limit and returns how long to wait before the next request.
|
||||
func (r *rateLimiter) exceeded() time.Duration {
|
||||
r.countLock.Lock()
|
||||
defer r.countLock.Unlock()
|
||||
|
||||
if time.Now().After(r.nextReset) {
|
||||
r.count = 0
|
||||
r.nextReset = time.Now().Add(r.window)
|
||||
}
|
||||
|
||||
r.count++
|
||||
|
||||
if r.count > r.limit {
|
||||
return time.Until(r.nextReset)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +21,7 @@ func initRouter(s *Server) {
|
||||
s.r.Use(
|
||||
s.requireValidAppVersion(),
|
||||
s.setSessionCookie(),
|
||||
s.applyRateLimit(),
|
||||
)
|
||||
|
||||
if core := s.r.Group("/core/v4"); core != nil {
|
||||
@@ -171,8 +173,23 @@ func (s *Server) setSessionCookie() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) applyRateLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if s.rateLimit == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if wait := s.rateLimit.exceeded(); wait > 0 {
|
||||
c.Header("Retry-After", strconv.Itoa(int(wait.Seconds())))
|
||||
c.AbortWithStatus(http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) logCalls() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
req, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -199,6 +216,9 @@ func (s *Server) logCalls() gin.HandlerFunc {
|
||||
Method: c.Request.Method,
|
||||
Status: c.Writer.Status(),
|
||||
|
||||
Time: start,
|
||||
Duration: time.Since(start),
|
||||
|
||||
RequestHeader: c.Request.Header,
|
||||
RequestBody: req,
|
||||
|
||||
|
||||
@@ -47,6 +47,9 @@ type Server struct {
|
||||
|
||||
// offline is whether to pretend the server is offline and return 5xx errors.
|
||||
offline bool
|
||||
|
||||
// rateLimit is the rate limiter for the server.
|
||||
rateLimit *rateLimiter
|
||||
}
|
||||
|
||||
func New(opts ...Option) *Server {
|
||||
@@ -69,7 +72,7 @@ func (s *Server) GetProxyURL() string {
|
||||
return s.s.URL + "/proxy"
|
||||
}
|
||||
|
||||
// GetDomain returns the domain of the server.
|
||||
// GetDomain returns the domain of the server (e.g. "proton.local").
|
||||
func (s *Server) GetDomain() string {
|
||||
return s.domain
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ type serverBuilder struct {
|
||||
logger io.Writer
|
||||
origin string
|
||||
cacher AuthCacher
|
||||
rateLimiter *rateLimiter
|
||||
}
|
||||
|
||||
func newServerBuilder() *serverBuilder {
|
||||
@@ -46,6 +47,7 @@ func (builder *serverBuilder) build() *Server {
|
||||
domain: builder.domain,
|
||||
proxyOrigin: builder.origin,
|
||||
authCacher: builder.cacher,
|
||||
rateLimit: builder.rateLimiter,
|
||||
}
|
||||
|
||||
if builder.withTLS {
|
||||
@@ -143,3 +145,19 @@ type withAuthCache struct {
|
||||
func (opt withAuthCache) config(builder *serverBuilder) {
|
||||
builder.cacher = opt.cacher
|
||||
}
|
||||
|
||||
func WithRateLimit(limit int, window time.Duration) Option {
|
||||
return &withRateLimit{
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
type withRateLimit struct {
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
func (opt withRateLimit) config(builder *serverBuilder) {
|
||||
builder.rateLimiter = newRateLimiter(opt.limit, opt.window)
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ func TestServer_Ping(t *testing.T) {
|
||||
|
||||
m := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
|
||||
var status proton.Status
|
||||
@@ -1313,7 +1313,7 @@ func TestServer_Messages_Fetch(t *testing.T) {
|
||||
|
||||
mm := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
defer mm.Close()
|
||||
|
||||
@@ -1355,7 +1355,7 @@ func TestServer_Messages_Status(t *testing.T) {
|
||||
|
||||
mm := proton.New(
|
||||
proton.WithHostURL(s.GetHostURL()),
|
||||
proton.WithTransport(proton.NewDialer(ctl, &tls.Config{InsecureSkipVerify: true}).GetRoundTripper()),
|
||||
proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})),
|
||||
)
|
||||
defer mm.Close()
|
||||
|
||||
|
||||
10
unlock.go
10
unlock.go
@@ -12,6 +12,8 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi
|
||||
userKR, err := user.Keys.Unlock(saltedKeyPass, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unlock user keys: %w", err)
|
||||
} else if userKR.CountDecryptionEntities() == 0 {
|
||||
return nil, nil, fmt.Errorf("failed to unlock any user keys")
|
||||
}
|
||||
|
||||
addrKRs := make(map[string]*crypto.KeyRing)
|
||||
@@ -19,9 +21,13 @@ func Unlock(user User, addresses []Address, saltedKeyPass []byte) (*crypto.KeyRi
|
||||
for idx, addrKR := range parallel.Map(runtime.NumCPU(), addresses, func(addr Address) *crypto.KeyRing {
|
||||
return addr.Keys.TryUnlock(saltedKeyPass, userKR)
|
||||
}) {
|
||||
if addrKR != nil {
|
||||
addrKRs[addresses[idx].ID] = addrKR
|
||||
if addrKR.CountDecryptionEntities() == 0 {
|
||||
continue
|
||||
} else if addrKR == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
addrKRs[addresses[idx].ID] = addrKR
|
||||
}
|
||||
|
||||
if len(addrKRs) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user