mirror of
https://github.com/kopia/kopia.git
synced 2026-05-19 12:14:45 -04:00
Merge github.com:kopia/repo into import-repo
This commit is contained in:
9
internal/repologging/logging.go
Normal file
9
internal/repologging/logging.go
Normal file
@@ -0,0 +1,9 @@
|
||||
// Package repologging provides loggers.
|
||||
package repologging
|
||||
|
||||
import "github.com/op/go-logging"
|
||||
|
||||
// Logger returns an instance of a logger used throughout repository codebase.
|
||||
func Logger(module string) *logging.Logger {
|
||||
return logging.MustGetLogger(module)
|
||||
}
|
||||
@@ -23,6 +23,7 @@ type Environment struct {
|
||||
|
||||
configDir string
|
||||
storageDir string
|
||||
connected bool
|
||||
}
|
||||
|
||||
// Setup sets up a test environment.
|
||||
@@ -75,6 +76,8 @@ func (e *Environment) Setup(t *testing.T, opts ...func(*repo.NewRepositoryOption
|
||||
t.Fatalf("can't connect: %v", err)
|
||||
}
|
||||
|
||||
e.connected = true
|
||||
|
||||
e.Repository, err = repo.Open(ctx, e.configFile(), masterPassword, &repo.Options{})
|
||||
if err != nil {
|
||||
t.Fatalf("can't open: %v", err)
|
||||
@@ -88,8 +91,13 @@ func (e *Environment) Close(t *testing.T) {
|
||||
if err := e.Repository.Close(context.Background()); err != nil {
|
||||
t.Fatalf("unable to close: %v", err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(e.configDir); err != nil {
|
||||
if e.connected {
|
||||
if err := repo.Disconnect(e.configFile()); err != nil {
|
||||
t.Errorf("error disconnecting: %v", err)
|
||||
}
|
||||
}
|
||||
if err := os.Remove(e.configDir); err != nil {
|
||||
// should be empty, assuming Disconnect was successful
|
||||
t.Errorf("error removing config directory: %v", err)
|
||||
}
|
||||
if err := os.RemoveAll(e.storageDir); err != nil {
|
||||
|
||||
44
internal/retry/retry.go
Normal file
44
internal/retry/retry.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Package retry implements exponential retry policy.
|
||||
package retry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/kopia/repo/internal/repologging"
|
||||
)
|
||||
|
||||
var log = repologging.Logger("repo/retry")
|
||||
|
||||
var (
|
||||
maxAttempts = 10
|
||||
retryInitialSleepAmount = 1 * time.Second
|
||||
retryMaxSleepAmount = 32 * time.Second
|
||||
)
|
||||
|
||||
// AttemptFunc performs an attempt and returns a value (optional, may be nil) and an error.
|
||||
type AttemptFunc func() (interface{}, error)
|
||||
|
||||
// IsRetriableFunc is a function that determines whether an error is retriable.
|
||||
type IsRetriableFunc func(err error) bool
|
||||
|
||||
// WithExponentialBackoff runs the provided attempt until it succeeds, retrying on all errors that are
|
||||
// deemed retriable by the provided function. The delay between retries grows exponentially up to
|
||||
// a certain limit.
|
||||
func WithExponentialBackoff(desc string, attempt AttemptFunc, isRetriableError IsRetriableFunc) (interface{}, error) {
|
||||
sleepAmount := retryInitialSleepAmount
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
v, err := attempt()
|
||||
if !isRetriableError(err) {
|
||||
return v, err
|
||||
}
|
||||
log.Debugf("got error %v when %v (#%v), sleeping for %v before retrying", err, desc, i, sleepAmount)
|
||||
time.Sleep(sleepAmount)
|
||||
sleepAmount *= 2
|
||||
if sleepAmount > retryMaxSleepAmount {
|
||||
sleepAmount = retryMaxSleepAmount
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to complete %v despite %v retries", desc, maxAttempts)
|
||||
}
|
||||
59
internal/retry/retry_test.go
Normal file
59
internal/retry/retry_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errRetriable = errors.New("retriable")
|
||||
)
|
||||
|
||||
func isRetriable(e error) bool {
|
||||
return e == errRetriable
|
||||
}
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
retryInitialSleepAmount = 10 * time.Millisecond
|
||||
retryMaxSleepAmount = 20 * time.Millisecond
|
||||
maxAttempts = 3
|
||||
|
||||
cnt := 0
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
f func() (interface{}, error)
|
||||
want interface{}
|
||||
wantError error
|
||||
}{
|
||||
{"success-nil", func() (interface{}, error) { return nil, nil }, nil, nil},
|
||||
{"success", func() (interface{}, error) { return 3, nil }, 3, nil},
|
||||
{"retriable-succeeds", func() (interface{}, error) {
|
||||
cnt++
|
||||
if cnt < 2 {
|
||||
return nil, errRetriable
|
||||
}
|
||||
return 4, nil
|
||||
}, 4, nil},
|
||||
{"retriable-never-succeeds", func() (interface{}, error) { return nil, errRetriable }, nil, fmt.Errorf("unable to complete retriable-never-succeeds despite 3 retries")},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tc := tc
|
||||
t.Parallel()
|
||||
|
||||
got, err := WithExponentialBackoff(tc.desc, tc.f, isRetriable)
|
||||
if !reflect.DeepEqual(err, tc.wantError) {
|
||||
t.Errorf("invalid error %q, wanted %q", err, tc.wantError)
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Errorf("invalid value %v, wanted %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
110
internal/storagetesting/asserts.go
Normal file
110
internal/storagetesting/asserts.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package storagetesting
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/kopia/repo/storage"
|
||||
)
|
||||
|
||||
// AssertGetBlock asserts that the specified storage block has correct content.
|
||||
func AssertGetBlock(ctx context.Context, t *testing.T, s storage.Storage, block string, expected []byte) {
|
||||
t.Helper()
|
||||
|
||||
b, err := s.GetBlock(ctx, block, 0, -1)
|
||||
if err != nil {
|
||||
t.Errorf("GetBlock(%v) returned error %v, expected data: %v", block, err, expected)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(b, expected) {
|
||||
t.Errorf("GetBlock(%v) returned %x, but expected %x", block, b, expected)
|
||||
}
|
||||
|
||||
half := int64(len(expected) / 2)
|
||||
if half == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
b, err = s.GetBlock(ctx, block, 0, 0)
|
||||
if err != nil {
|
||||
t.Errorf("GetBlock(%v) returned error %v, expected data: %v", block, err, expected)
|
||||
return
|
||||
}
|
||||
|
||||
if len(b) != 0 {
|
||||
t.Errorf("GetBlock(%v) returned non-zero length: %v", block, len(b))
|
||||
return
|
||||
}
|
||||
|
||||
b, err = s.GetBlock(ctx, block, 0, half)
|
||||
if err != nil {
|
||||
t.Errorf("GetBlock(%v) returned error %v, expected data: %v", block, err, expected)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(b, expected[0:half]) {
|
||||
t.Errorf("GetBlock(%v) returned %x, but expected %x", block, b, expected[0:half])
|
||||
}
|
||||
|
||||
b, err = s.GetBlock(ctx, block, half, int64(len(expected))-half)
|
||||
if err != nil {
|
||||
t.Errorf("GetBlock(%v) returned error %v, expected data: %v", block, err, expected)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(b, expected[len(expected)-int(half):]) {
|
||||
t.Errorf("GetBlock(%v) returned %x, but expected %x", block, b, expected[len(expected)-int(half):])
|
||||
}
|
||||
|
||||
AssertInvalidOffsetLength(ctx, t, s, block, -3, 1)
|
||||
AssertInvalidOffsetLength(ctx, t, s, block, int64(len(expected)), 3)
|
||||
AssertInvalidOffsetLength(ctx, t, s, block, int64(len(expected)-1), 3)
|
||||
AssertInvalidOffsetLength(ctx, t, s, block, int64(len(expected)+1), 3)
|
||||
}
|
||||
|
||||
// AssertInvalidOffsetLength verifies that the given combination of (offset,length) fails on GetBlock()
|
||||
func AssertInvalidOffsetLength(ctx context.Context, t *testing.T, s storage.Storage, block string, offset, length int64) {
|
||||
if _, err := s.GetBlock(ctx, block, offset, length); err == nil {
|
||||
t.Errorf("GetBlock(%v,%v,%v) did not return error for invalid offset/length", block, offset, length)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertGetBlockNotFound asserts that GetBlock() for specified storage block returns ErrBlockNotFound.
|
||||
func AssertGetBlockNotFound(ctx context.Context, t *testing.T, s storage.Storage, block string) {
|
||||
t.Helper()
|
||||
|
||||
b, err := s.GetBlock(ctx, block, 0, -1)
|
||||
if err != storage.ErrBlockNotFound || b != nil {
|
||||
t.Errorf("GetBlock(%v) returned %v, %v but expected ErrBlockNotFound", block, b, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AssertListResults asserts that the list results with given prefix return the specified list of names in order.
|
||||
func AssertListResults(ctx context.Context, t *testing.T, s storage.Storage, prefix string, want ...string) {
|
||||
t.Helper()
|
||||
var names []string
|
||||
|
||||
if err := s.ListBlocks(ctx, prefix, func(e storage.BlockMetadata) error {
|
||||
names = append(names, e.BlockID)
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
names = sorted(names)
|
||||
want = sorted(want)
|
||||
|
||||
if !reflect.DeepEqual(names, want) {
|
||||
t.Errorf("ListBlocks(%v) returned %v, but wanted %v", prefix, names, want)
|
||||
}
|
||||
}
|
||||
|
||||
func sorted(s []string) []string {
|
||||
x := append([]string(nil), s...)
|
||||
sort.Strings(x)
|
||||
return x
|
||||
}
|
||||
2
internal/storagetesting/doc.go
Normal file
2
internal/storagetesting/doc.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package storagetesting is used for testing Storage implementations.
|
||||
package storagetesting
|
||||
115
internal/storagetesting/faulty.go
Normal file
115
internal/storagetesting/faulty.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package storagetesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kopia/repo/internal/repologging"
|
||||
"github.com/kopia/repo/storage"
|
||||
)
|
||||
|
||||
var log = repologging.Logger("faulty-storage")
|
||||
|
||||
// Fault describes the behavior of a single fault.
|
||||
type Fault struct {
|
||||
Repeat int // how many times to repeat this fault
|
||||
Sleep time.Duration // sleep before returning
|
||||
ErrCallback func() error
|
||||
WaitFor chan struct{} // waits until the given channel is closed before returning
|
||||
Err error // error to return (can be nil in combination with Sleep and WaitFor)
|
||||
}
|
||||
|
||||
// FaultyStorage implements fault injection for Storage.
|
||||
type FaultyStorage struct {
|
||||
Base storage.Storage
|
||||
Faults map[string][]*Fault
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// GetBlock implements storage.Storage
|
||||
func (s *FaultyStorage) GetBlock(ctx context.Context, id string, offset, length int64) ([]byte, error) {
|
||||
if err := s.getNextFault("GetBlock", id, offset, length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.Base.GetBlock(ctx, id, offset, length)
|
||||
}
|
||||
|
||||
// PutBlock implements storage.Storage
|
||||
func (s *FaultyStorage) PutBlock(ctx context.Context, id string, data []byte) error {
|
||||
if err := s.getNextFault("PutBlock", id, len(data)); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Base.PutBlock(ctx, id, data)
|
||||
}
|
||||
|
||||
// DeleteBlock implements storage.Storage
|
||||
func (s *FaultyStorage) DeleteBlock(ctx context.Context, id string) error {
|
||||
if err := s.getNextFault("DeleteBlock", id); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Base.DeleteBlock(ctx, id)
|
||||
}
|
||||
|
||||
// ListBlocks implements storage.Storage
|
||||
func (s *FaultyStorage) ListBlocks(ctx context.Context, prefix string, callback func(storage.BlockMetadata) error) error {
|
||||
if err := s.getNextFault("ListBlocks", prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.Base.ListBlocks(ctx, prefix, func(bm storage.BlockMetadata) error {
|
||||
if err := s.getNextFault("ListBlocksItem", prefix); err != nil {
|
||||
return err
|
||||
}
|
||||
return callback(bm)
|
||||
})
|
||||
}
|
||||
|
||||
// Close implements storage.Storage
|
||||
func (s *FaultyStorage) Close(ctx context.Context) error {
|
||||
if err := s.getNextFault("Close"); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Base.Close(ctx)
|
||||
}
|
||||
|
||||
// ConnectionInfo implements storage.Storage
|
||||
func (s *FaultyStorage) ConnectionInfo() storage.ConnectionInfo {
|
||||
return s.Base.ConnectionInfo()
|
||||
}
|
||||
|
||||
func (s *FaultyStorage) getNextFault(method string, args ...interface{}) error {
|
||||
s.mu.Lock()
|
||||
faults := s.Faults[method]
|
||||
if len(faults) == 0 {
|
||||
s.mu.Unlock()
|
||||
log.Debugf("no faults for %v %v", method, args)
|
||||
return nil
|
||||
}
|
||||
|
||||
f := faults[0]
|
||||
if f.Repeat > 0 {
|
||||
f.Repeat--
|
||||
log.Debugf("will repeat %v more times the fault for %v %v", f.Repeat, method, args)
|
||||
} else {
|
||||
s.Faults[method] = faults[1:]
|
||||
}
|
||||
s.mu.Unlock()
|
||||
if f.WaitFor != nil {
|
||||
log.Debugf("waiting for channel to be closed in %v %v", method, args)
|
||||
<-f.WaitFor
|
||||
}
|
||||
if f.Sleep > 0 {
|
||||
log.Debugf("sleeping for %v in %v %v", f.Sleep, method, args)
|
||||
}
|
||||
if f.ErrCallback != nil {
|
||||
err := f.ErrCallback()
|
||||
log.Debugf("returning %v for %v %v", err, method, args)
|
||||
return err
|
||||
}
|
||||
log.Debugf("returning %v for %v %v", f.Err, method, args)
|
||||
return f.Err
|
||||
}
|
||||
|
||||
var _ storage.Storage = (*FaultyStorage)(nil)
|
||||
133
internal/storagetesting/map.go
Normal file
133
internal/storagetesting/map.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package storagetesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kopia/repo/storage"
|
||||
)
|
||||
|
||||
type mapStorage struct {
|
||||
data map[string][]byte
|
||||
keyTime map[string]time.Time
|
||||
timeNow func() time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *mapStorage) GetBlock(ctx context.Context, id string, offset, length int64) ([]byte, error) {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
data, ok := s.data[id]
|
||||
if ok {
|
||||
data = append([]byte(nil), data...)
|
||||
if length < 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
if int(offset) > len(data) || offset < 0 {
|
||||
return nil, errors.New("invalid offset")
|
||||
}
|
||||
|
||||
data = data[offset:]
|
||||
if int(length) > len(data) {
|
||||
return nil, errors.New("invalid length")
|
||||
}
|
||||
return data[0:length], nil
|
||||
}
|
||||
|
||||
return nil, storage.ErrBlockNotFound
|
||||
}
|
||||
|
||||
func (s *mapStorage) PutBlock(ctx context.Context, id string, data []byte) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if _, ok := s.data[id]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.keyTime[id] = s.timeNow()
|
||||
s.data[id] = append([]byte{}, data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mapStorage) DeleteBlock(ctx context.Context, id string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
delete(s.data, id)
|
||||
delete(s.keyTime, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mapStorage) ListBlocks(ctx context.Context, prefix string, callback func(storage.BlockMetadata) error) error {
|
||||
s.mutex.RLock()
|
||||
|
||||
keys := []string{}
|
||||
for k := range s.data {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
}
|
||||
s.mutex.RUnlock()
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
s.mutex.RLock()
|
||||
v, ok := s.data[k]
|
||||
ts := s.keyTime[k]
|
||||
s.mutex.RUnlock()
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := callback(storage.BlockMetadata{
|
||||
BlockID: k,
|
||||
Length: int64(len(v)),
|
||||
Timestamp: ts,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mapStorage) Close(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mapStorage) TouchBlock(ctx context.Context, blockID string, threshold time.Duration) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if v, ok := s.keyTime[blockID]; ok {
|
||||
n := s.timeNow()
|
||||
if n.Sub(v) >= threshold {
|
||||
s.keyTime[blockID] = n
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mapStorage) ConnectionInfo() storage.ConnectionInfo {
|
||||
// unsupported
|
||||
return storage.ConnectionInfo{}
|
||||
}
|
||||
|
||||
// NewMapStorage returns an implementation of Storage backed by the contents of given map.
|
||||
// Used primarily for testing.
|
||||
func NewMapStorage(data map[string][]byte, keyTime map[string]time.Time, timeNow func() time.Time) storage.Storage {
|
||||
if keyTime == nil {
|
||||
keyTime = make(map[string]time.Time)
|
||||
}
|
||||
if timeNow == nil {
|
||||
timeNow = time.Now
|
||||
}
|
||||
return &mapStorage{data: data, keyTime: keyTime, timeNow: timeNow}
|
||||
}
|
||||
15
internal/storagetesting/map_test.go
Normal file
15
internal/storagetesting/map_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package storagetesting
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMapStorage(t *testing.T) {
|
||||
data := map[string][]byte{}
|
||||
r := NewMapStorage(data, nil, nil)
|
||||
if r == nil {
|
||||
t.Errorf("unexpected result: %v", r)
|
||||
}
|
||||
VerifyStorage(context.Background(), t, r)
|
||||
}
|
||||
84
internal/storagetesting/verify.go
Normal file
84
internal/storagetesting/verify.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package storagetesting
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/kopia/repo/storage"
|
||||
)
|
||||
|
||||
// VerifyStorage verifies the behavior of the specified storage.
|
||||
func VerifyStorage(ctx context.Context, t *testing.T, r storage.Storage) {
|
||||
blocks := []struct {
|
||||
blk string
|
||||
contents []byte
|
||||
}{
|
||||
{blk: string("abcdbbf4f0507d054ed5a80a5b65086f602b"), contents: []byte{}},
|
||||
{blk: string("zxce0e35630770c54668a8cfb4e414c6bf8f"), contents: []byte{1}},
|
||||
{blk: string("abff4585856ebf0748fd989e1dd623a8963d"), contents: bytes.Repeat([]byte{1}, 1000)},
|
||||
{blk: string("abgc3dca496d510f492c858a2df1eb824e62"), contents: bytes.Repeat([]byte{1}, 10000)},
|
||||
{blk: string("kopia.repository"), contents: bytes.Repeat([]byte{2}, 100)},
|
||||
}
|
||||
|
||||
// First verify that blocks don't exist.
|
||||
for _, b := range blocks {
|
||||
AssertGetBlockNotFound(ctx, t, r, b.blk)
|
||||
}
|
||||
|
||||
ctx2 := storage.WithUploadProgressCallback(ctx, func(desc string, completed, total int64) {
|
||||
log.Infof("progress %v: %v/%v", desc, completed, total)
|
||||
})
|
||||
|
||||
// Now add blocks.
|
||||
for _, b := range blocks {
|
||||
if err := r.PutBlock(ctx2, b.blk, b.contents); err != nil {
|
||||
t.Errorf("can't put block: %v", err)
|
||||
}
|
||||
|
||||
AssertGetBlock(ctx, t, r, b.blk, b.contents)
|
||||
}
|
||||
|
||||
AssertListResults(ctx, t, r, "", blocks[0].blk, blocks[1].blk, blocks[2].blk, blocks[3].blk, blocks[4].blk)
|
||||
AssertListResults(ctx, t, r, "ab", blocks[0].blk, blocks[2].blk, blocks[3].blk)
|
||||
|
||||
// Overwrite blocks.
|
||||
for _, b := range blocks {
|
||||
if err := r.PutBlock(ctx, b.blk, b.contents); err != nil {
|
||||
t.Errorf("can't put block: %v", err)
|
||||
}
|
||||
|
||||
AssertGetBlock(ctx, t, r, b.blk, b.contents)
|
||||
}
|
||||
|
||||
if err := r.DeleteBlock(ctx, blocks[0].blk); err != nil {
|
||||
t.Errorf("unable to delete block: %v", err)
|
||||
}
|
||||
if err := r.DeleteBlock(ctx, blocks[0].blk); err != nil {
|
||||
t.Errorf("invalid error when deleting deleted block: %v", err)
|
||||
}
|
||||
AssertListResults(ctx, t, r, "ab", blocks[2].blk, blocks[3].blk)
|
||||
AssertListResults(ctx, t, r, "", blocks[1].blk, blocks[2].blk, blocks[3].blk, blocks[4].blk)
|
||||
}
|
||||
|
||||
// AssertConnectionInfoRoundTrips verifies that the ConnectionInfo returned by a given storage can be used to create
|
||||
// equivalent storage
|
||||
func AssertConnectionInfoRoundTrips(ctx context.Context, t *testing.T, s storage.Storage) {
|
||||
t.Helper()
|
||||
|
||||
ci := s.ConnectionInfo()
|
||||
s2, err := storage.NewStorage(ctx, ci)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
ci2 := s2.ConnectionInfo()
|
||||
if !reflect.DeepEqual(ci, ci2) {
|
||||
t.Errorf("connection info does not round-trip: %v vs %v", ci, ci2)
|
||||
}
|
||||
|
||||
if err := s2.Close(ctx); err != nil {
|
||||
t.Errorf("unable to close storage: %v", err)
|
||||
}
|
||||
}
|
||||
44
internal/throttle/round_tripper.go
Normal file
44
internal/throttle/round_tripper.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package throttle
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type throttlerPool interface {
|
||||
AddReader(io.ReadCloser) (io.ReadCloser, error)
|
||||
}
|
||||
|
||||
type throttlingRoundTripper struct {
|
||||
base http.RoundTripper
|
||||
downloadPool throttlerPool
|
||||
uploadPool throttlerPool
|
||||
}
|
||||
|
||||
func (rt *throttlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil && rt.uploadPool != nil {
|
||||
var err error
|
||||
req.Body, err = rt.uploadPool.AddReader(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
resp, err := rt.base.RoundTrip(req)
|
||||
if resp != nil && resp.Body != nil && rt.downloadPool != nil {
|
||||
resp.Body, err = rt.downloadPool.AddReader(resp.Body)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// NewRoundTripper returns http.RoundTripper that throttles upload and downloads.
|
||||
func NewRoundTripper(base http.RoundTripper, downloadPool throttlerPool, uploadPool throttlerPool) http.RoundTripper {
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &throttlingRoundTripper{
|
||||
base: base,
|
||||
downloadPool: downloadPool,
|
||||
uploadPool: uploadPool,
|
||||
}
|
||||
}
|
||||
103
internal/throttle/round_tripper_test.go
Normal file
103
internal/throttle/round_tripper_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package throttle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type baseRoundTripper struct {
|
||||
responses map[*http.Request]*http.Response
|
||||
}
|
||||
|
||||
func (rt *baseRoundTripper) add(req *http.Request, resp *http.Response) (*http.Request, *http.Response) {
|
||||
rt.responses[req] = resp
|
||||
return req, resp
|
||||
}
|
||||
|
||||
func (rt *baseRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp := rt.responses[req]
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("error occurred")
|
||||
}
|
||||
|
||||
type fakePool struct {
|
||||
readers []io.ReadCloser
|
||||
}
|
||||
|
||||
func (fp *fakePool) reset() {
|
||||
fp.readers = nil
|
||||
}
|
||||
|
||||
func (fp *fakePool) AddReader(r io.ReadCloser) (io.ReadCloser, error) {
|
||||
fp.readers = append(fp.readers, r)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func TestRoundTripper(t *testing.T) {
|
||||
downloadBody := ioutil.NopCloser(bytes.NewReader([]byte("data1")))
|
||||
uploadBody := ioutil.NopCloser(bytes.NewReader([]byte("data1")))
|
||||
|
||||
base := &baseRoundTripper{
|
||||
responses: make(map[*http.Request]*http.Response),
|
||||
}
|
||||
downloadPool := &fakePool{}
|
||||
uploadPool := &fakePool{}
|
||||
rt := NewRoundTripper(base, downloadPool, uploadPool)
|
||||
|
||||
// Empty request (no request, no response)
|
||||
uploadPool.reset()
|
||||
downloadPool.reset()
|
||||
req1, resp1 := base.add(&http.Request{}, &http.Response{})
|
||||
resp, err := rt.RoundTrip(req1)
|
||||
if resp != resp1 || err != nil {
|
||||
t.Errorf("invalid response or error: %v", err)
|
||||
}
|
||||
if len(downloadPool.readers) != 0 || len(uploadPool.readers) != 0 {
|
||||
t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers)
|
||||
}
|
||||
|
||||
// Upload request
|
||||
uploadPool.reset()
|
||||
downloadPool.reset()
|
||||
req2, resp2 := base.add(&http.Request{
|
||||
Body: uploadBody,
|
||||
}, &http.Response{})
|
||||
resp, err = rt.RoundTrip(req2)
|
||||
if resp != resp2 || err != nil {
|
||||
t.Errorf("invalid response or error: %v", err)
|
||||
}
|
||||
if len(downloadPool.readers) != 0 || len(uploadPool.readers) != 1 {
|
||||
t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers)
|
||||
}
|
||||
|
||||
// Download request
|
||||
uploadPool.reset()
|
||||
downloadPool.reset()
|
||||
req3, resp3 := base.add(&http.Request{}, &http.Response{Body: downloadBody})
|
||||
resp, err = rt.RoundTrip(req3)
|
||||
if resp != resp3 || err != nil {
|
||||
t.Errorf("invalid response or error: %v", err)
|
||||
}
|
||||
if len(downloadPool.readers) != 1 || len(uploadPool.readers) != 0 {
|
||||
t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers)
|
||||
}
|
||||
|
||||
// Upload/Download request
|
||||
uploadPool.reset()
|
||||
downloadPool.reset()
|
||||
req4, resp4 := base.add(&http.Request{Body: uploadBody}, &http.Response{Body: downloadBody})
|
||||
resp, err = rt.RoundTrip(req4)
|
||||
if resp != resp4 || err != nil {
|
||||
t.Errorf("invalid response or error: %v", err)
|
||||
}
|
||||
if len(downloadPool.readers) != 1 || len(uploadPool.readers) != 1 {
|
||||
t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user