diff --git a/cli/storage_azure.go b/cli/storage_azure.go index 32086e05e..b66b8fdfd 100644 --- a/cli/storage_azure.go +++ b/cli/storage_azure.go @@ -20,8 +20,8 @@ func (c *storageAzureFlags) setup(_ storageProviderServices, cmd *kingpin.CmdCla cmd.Flag("storage-domain", "Azure storage domain").Envar("AZURE_STORAGE_DOMAIN").StringVar(&c.azOptions.StorageDomain) cmd.Flag("sas-token", "Azure SAS Token").Envar("AZURE_STORAGE_SAS_TOKEN").StringVar(&c.azOptions.SASToken) cmd.Flag("prefix", "Prefix to use for objects in the bucket").StringVar(&c.azOptions.Prefix) - cmd.Flag("max-download-speed", "Limit the download speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.azOptions.MaxDownloadSpeedBytesPerSecond) - cmd.Flag("max-upload-speed", "Limit the upload speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.azOptions.MaxUploadSpeedBytesPerSecond) + + commonThrottlingFlags(cmd, &c.azOptions.Limits) } func (c *storageAzureFlags) connect(ctx context.Context, isCreate bool, formatVersion int) (blob.Storage, error) { diff --git a/cli/storage_b2.go b/cli/storage_b2.go index 9a3991ef1..fe537e6dc 100644 --- a/cli/storage_b2.go +++ b/cli/storage_b2.go @@ -18,8 +18,7 @@ func (c *storageB2Flags) setup(_ storageProviderServices, cmd *kingpin.CmdClause cmd.Flag("key-id", "Key ID (overrides B2_KEY_ID environment variable)").Required().Envar("B2_KEY_ID").StringVar(&c.b2options.KeyID) cmd.Flag("key", "Secret key (overrides B2_KEY environment variable)").Required().Envar("B2_KEY").StringVar(&c.b2options.Key) cmd.Flag("prefix", "Prefix to use for objects in the bucket").StringVar(&c.b2options.Prefix) - cmd.Flag("max-download-speed", "Limit the download speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.b2options.MaxDownloadSpeedBytesPerSecond) - cmd.Flag("max-upload-speed", "Limit the upload speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.b2options.MaxUploadSpeedBytesPerSecond) + commonThrottlingFlags(cmd, &c.b2options.Limits) } func (c *storageB2Flags) connect(ctx context.Context, isCreate bool, 
formatVersion int) (blob.Storage, error) { diff --git a/cli/storage_filesystem.go b/cli/storage_filesystem.go index e17d18229..c93101ad4 100644 --- a/cli/storage_filesystem.go +++ b/cli/storage_filesystem.go @@ -36,6 +36,8 @@ func (c *storageFilesystemFlags) setup(_ storageProviderServices, cmd *kingpin.C cmd.Flag("dir-mode", "Mode of newly directory files (0700)").PlaceHolder("MODE").StringVar(&c.connectDirMode) cmd.Flag("flat", "Use flat directory structure").BoolVar(&c.connectFlat) cmd.Flag("list-parallelism", "Set list parallelism").Hidden().IntVar(&c.options.ListParallelism) + + commonThrottlingFlags(cmd, &c.options.Limits) } func (c *storageFilesystemFlags) connect(ctx context.Context, isCreate bool, formatVersion int) (blob.Storage, error) { diff --git a/cli/storage_gcs.go b/cli/storage_gcs.go index c0d323ec1..441ecceb9 100644 --- a/cli/storage_gcs.go +++ b/cli/storage_gcs.go @@ -23,9 +23,9 @@ func (c *storageGCSFlags) setup(_ storageProviderServices, cmd *kingpin.CmdClaus cmd.Flag("prefix", "Prefix to use for objects in the bucket").StringVar(&c.options.Prefix) cmd.Flag("read-only", "Use read-only GCS scope to prevent write access").BoolVar(&c.options.ReadOnly) cmd.Flag("credentials-file", "Use the provided JSON file with credentials").ExistingFileVar(&c.options.ServiceAccountCredentialsFile) - cmd.Flag("max-download-speed", "Limit the download speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.options.MaxDownloadSpeedBytesPerSecond) - cmd.Flag("max-upload-speed", "Limit the upload speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.options.MaxUploadSpeedBytesPerSecond) cmd.Flag("embed-credentials", "Embed GCS credentials JSON in Kopia configuration").BoolVar(&c.embedCredentials) + + commonThrottlingFlags(cmd, &c.options.Limits) } func (c *storageGCSFlags) connect(ctx context.Context, isCreate bool, formatVersion int) (blob.Storage, error) { diff --git a/cli/storage_providers.go b/cli/storage_providers.go index caa2a29c3..31afd1b1e 100644 --- 
a/cli/storage_providers.go +++ b/cli/storage_providers.go @@ -6,6 +6,7 @@ "github.com/alecthomas/kingpin" "github.com/kopia/kopia/repo/blob" + "github.com/kopia/kopia/repo/blob/throttling" ) type storageProviderServices interface { @@ -35,3 +36,8 @@ type storageProvider struct { {"sftp", "an SFTP storage", func() storageFlags { return &storageSFTPFlags{} }}, {"webdav", "a WebDAV storage", func() storageFlags { return &storageWebDAVFlags{} }}, } + +func commonThrottlingFlags(cmd *kingpin.CmdClause, limits *throttling.Limits) { + cmd.Flag("max-download-speed", "Limit the download speed.").PlaceHolder("BYTES_PER_SEC").FloatVar(&limits.DownloadBytesPerSecond) + cmd.Flag("max-upload-speed", "Limit the upload speed.").PlaceHolder("BYTES_PER_SEC").FloatVar(&limits.UploadBytesPerSecond) +} diff --git a/cli/storage_rclone.go b/cli/storage_rclone.go index 53263a009..6c7e6a4a9 100644 --- a/cli/storage_rclone.go +++ b/cli/storage_rclone.go @@ -28,6 +28,8 @@ func (c *storageRcloneFlags) setup(_ storageProviderServices, cmd *kingpin.CmdCl cmd.Flag("rclone-nowait-for-transfers", "Don't wait for transfers when closing storage").Hidden().BoolVar(&c.opt.NoWaitForTransfers) cmd.Flag("list-parallelism", "Set list parallelism").Hidden().IntVar(&c.opt.ListParallelism) cmd.Flag("atomic-writes", "Assume provider writes are atomic").Default("true").BoolVar(&c.opt.AtomicWrites) + + commonThrottlingFlags(cmd, &c.opt.Limits) } func (c *storageRcloneFlags) connect(ctx context.Context, isCreate bool, formatVersion int) (blob.Storage, error) { diff --git a/cli/storage_s3.go b/cli/storage_s3.go index 813057ee7..93b39f5cc 100644 --- a/cli/storage_s3.go +++ b/cli/storage_s3.go @@ -25,8 +25,8 @@ func (c *storageS3Flags) setup(_ storageProviderServices, cmd *kingpin.CmdClause cmd.Flag("prefix", "Prefix to use for objects in the bucket").StringVar(&c.s3options.Prefix) cmd.Flag("disable-tls", "Disable TLS security (HTTPS)").BoolVar(&c.s3options.DoNotUseTLS) cmd.Flag("disable-tls-verification", "Disable 
TLS (HTTPS) certificate verification").BoolVar(&c.s3options.DoNotVerifyTLS) - cmd.Flag("max-download-speed", "Limit the download speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.s3options.MaxDownloadSpeedBytesPerSecond) - cmd.Flag("max-upload-speed", "Limit the upload speed.").PlaceHolder("BYTES_PER_SEC").IntVar(&c.s3options.MaxUploadSpeedBytesPerSecond) + + commonThrottlingFlags(cmd, &c.s3options.Limits) var pointInTimeStr string diff --git a/cli/storage_sftp.go b/cli/storage_sftp.go index eec03ba7e..96c8a56b0 100644 --- a/cli/storage_sftp.go +++ b/cli/storage_sftp.go @@ -41,6 +41,8 @@ func (c *storageSFTPFlags) setup(_ storageProviderServices, cmd *kingpin.CmdClau cmd.Flag("flat", "Use flat directory structure").BoolVar(&c.connectFlat) cmd.Flag("list-parallelism", "Set list parallelism").Hidden().IntVar(&c.options.ListParallelism) + + commonThrottlingFlags(cmd, &c.options.Limits) } func (c *storageSFTPFlags) getOptions(formatVersion int) (*sftp.Options, error) { diff --git a/cli/storage_webdav.go b/cli/storage_webdav.go index bbe95298e..472674c9b 100644 --- a/cli/storage_webdav.go +++ b/cli/storage_webdav.go @@ -22,6 +22,8 @@ func (c *storageWebDAVFlags) setup(_ storageProviderServices, cmd *kingpin.CmdCl cmd.Flag("webdav-password", "WebDAV password").Envar("KOPIA_WEBDAV_PASSWORD").StringVar(&c.options.Password) cmd.Flag("list-parallelism", "Set list parallelism").Hidden().IntVar(&c.options.ListParallelism) cmd.Flag("atomic-writes", "Assume WebDAV provider implements atomic writes").BoolVar(&c.options.AtomicWrites) + + commonThrottlingFlags(cmd, &c.options.Limits) } func (c *storageWebDAVFlags) connect(ctx context.Context, isCreate bool, formatVersion int) (blob.Storage, error) { diff --git a/go.mod b/go.mod index ba9a7d637..2ca564da9 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/aws/aws-sdk-go v1.41.16 github.com/chmduquesne/rollinghash v4.0.0+incompatible github.com/dustinkirkland/golang-petname v0.0.0-20191129215211-8e5a1ed0cff0 - 
github.com/efarrer/iothrottler v0.0.3 github.com/fatih/color v1.13.0 github.com/foomo/htpasswd v0.0.0-20200116085101-e3a90e78da9c github.com/frankban/quicktest v1.13.1 // indirect diff --git a/go.sum b/go.sum index 774981c84..681b9b95d 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,6 @@ github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5m github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= -github.com/efarrer/iothrottler v0.0.3 h1:6m8eKBQ1ouigjXQoBxwEWz12fUGGYfYppEJVcyZFcYg= -github.com/efarrer/iothrottler v0.0.3/go.mod h1:zGWF5N0NKSCskcPFytDAFwI121DdU/NfW4XOjpTR+ys= github.com/envoyproxy/go-control-plane v0.6.9/go.mod h1:SBwIajubJHhxtWwsL9s8ss4safvEdbitLhGGK48rN6g= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -677,8 +675,6 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/studio-b12/gowebdav v0.0.0-20210917133250-a3a86976a1df h1:C+J/LwTqP8gRPt1MdSzBNZP0OYuDm5wsmDKgwpLjYzo= -github.com/studio-b12/gowebdav v0.0.0-20210917133250-a3a86976a1df/go.mod h1:gCcfDlA1Y7GqOaeEKw5l9dOGx1VLdc/HuQSlQAaZ30s= github.com/studio-b12/gowebdav v0.0.0-20211106090535-29e74efa701f h1:SLJx0nHhb2ZLlYNMAbrYsjwmVwXx4yRT48lNIxOp7ts= github.com/studio-b12/gowebdav v0.0.0-20211106090535-29e74efa701f/go.mod 
h1:gCcfDlA1Y7GqOaeEKw5l9dOGx1VLdc/HuQSlQAaZ30s= github.com/tg123/go-htpasswd v1.2.0 h1:UKp34m9H467/xklxUxU15wKRru7fwXoTojtxg25ITF0= diff --git a/internal/server/api_repo.go b/internal/server/api_repo.go index dcee0bb2e..1f22e0439 100644 --- a/internal/server/api_repo.go +++ b/internal/server/api_repo.go @@ -14,6 +14,7 @@ "github.com/kopia/kopia/internal/serverapi" "github.com/kopia/kopia/repo" "github.com/kopia/kopia/repo/blob" + "github.com/kopia/kopia/repo/blob/throttling" "github.com/kopia/kopia/repo/compression" "github.com/kopia/kopia/repo/encryption" "github.com/kopia/kopia/repo/hashing" @@ -243,6 +244,33 @@ func (s *Server) handleRepoSupportedAlgorithms(ctx context.Context, r *http.Requ return res, nil } +func (s *Server) handleRepoGetThrottle(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + dr, ok := s.rep.(repo.DirectRepository) + if !ok { + return nil, requestError(serverapi.ErrorStorageConnection, "no direct storage connection") + } + + return dr.Throttler().Limits(), nil +} + +func (s *Server) handleRepoSetThrottle(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + dr, ok := s.rep.(repo.DirectRepository) + if !ok { + return nil, requestError(serverapi.ErrorStorageConnection, "no direct storage connection") + } + + var req throttling.Limits + if err := json.Unmarshal(body, &req); err != nil { + return nil, requestError(serverapi.ErrorMalformedRequest, "unable to decode request: "+err.Error()) + } + + if err := dr.Throttler().SetLimits(req); err != nil { + return nil, requestError(serverapi.ErrorMalformedRequest, "unable to set limits: "+err.Error()) + } + + return &serverapi.Empty{}, nil +} + func (s *Server) getConnectOptions(cliOpts repo.ClientOptions) *repo.ConnectOptions { o := *s.options.ConnectOptions o.ClientOptions = o.ClientOptions.Override(cliOpts) diff --git a/internal/server/server.go b/internal/server/server.go index 44bfa851b..25c018b94 100644 --- a/internal/server/server.go +++ 
b/internal/server/server.go @@ -107,6 +107,8 @@ func (s *Server) APIHandlers(legacyAPI bool) http.Handler { m.HandleFunc("/api/v1/repo/disconnect", s.handleAPI(requireUIUser, s.handleRepoDisconnect)).Methods(http.MethodPost) m.HandleFunc("/api/v1/repo/algorithms", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoGetThrottle)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoSetThrottle)).Methods(http.MethodPut) if legacyAPI { m.HandleFunc("/api/v1/repo/parameters", s.handleAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet) diff --git a/internal/serverapi/client_wrappers.go b/internal/serverapi/client_wrappers.go index 18608359c..820ff33e9 100644 --- a/internal/serverapi/client_wrappers.go +++ b/internal/serverapi/client_wrappers.go @@ -8,6 +8,7 @@ "github.com/kopia/kopia/internal/apiclient" "github.com/kopia/kopia/internal/uitask" + "github.com/kopia/kopia/repo/blob/throttling" "github.com/kopia/kopia/repo/object" "github.com/kopia/kopia/snapshot" ) @@ -96,6 +97,25 @@ func Status(ctx context.Context, c *apiclient.KopiaAPIClient) (*StatusResponse, return resp, nil } +// GetThrottlingLimits gets the throttling limits. +func GetThrottlingLimits(ctx context.Context, c *apiclient.KopiaAPIClient) (throttling.Limits, error) { + resp := throttling.Limits{} + if err := c.Get(ctx, "repo/throttle", nil, &resp); err != nil { + return throttling.Limits{}, errors.Wrap(err, "throttling") + } + + return resp, nil +} + +// SetThrottlingLimits sets the throttling limits. +func SetThrottlingLimits(ctx context.Context, c *apiclient.KopiaAPIClient, l throttling.Limits) error { + if err := c.Put(ctx, "repo/throttle", &l, &Empty{}); err != nil { + return errors.Wrap(err, "throttling") + } + + return nil +} + // ListSources lists the snapshot sources managed by the server. 
func ListSources(ctx context.Context, c *apiclient.KopiaAPIClient, match *snapshot.SourceInfo) (*SourcesResponse, error) { resp := &SourcesResponse{} diff --git a/internal/throttle/round_tripper.go b/internal/throttle/round_tripper.go deleted file mode 100644 index bfd995a61..000000000 --- a/internal/throttle/round_tripper.go +++ /dev/null @@ -1,55 +0,0 @@ -// Package throttle implements helpers for throttling uploads and downloads. -package throttle - -import ( - "io" - "net/http" - - "github.com/pkg/errors" -) - -type throttlerPool interface { - AddReader(io.ReadCloser) (io.ReadCloser, error) -} - -type throttlingRoundTripper struct { - base http.RoundTripper - downloadPool throttlerPool - uploadPool throttlerPool -} - -func (rt *throttlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if req.Body != nil && rt.uploadPool != nil { - var err error - - req.Body, err = rt.uploadPool.AddReader(req.Body) - if err != nil { - return nil, errors.Wrap(err, "unable to attach request throttler") - } - } - - resp, err := rt.base.RoundTrip(req) - - if resp != nil && resp.Body != nil && rt.downloadPool != nil { - resp.Body, err = rt.downloadPool.AddReader(resp.Body) - if err != nil { - return nil, errors.Wrap(err, "unable to attach response throttler") - } - } - - // nolint:wrapcheck - return resp, err -} - -// NewRoundTripper returns http.RoundTripper that throttles upload and downloads. 
-func NewRoundTripper(base http.RoundTripper, downloadPool, uploadPool throttlerPool) http.RoundTripper { - if base == nil { - base = http.DefaultTransport - } - - return &throttlingRoundTripper{ - base: base, - downloadPool: downloadPool, - uploadPool: uploadPool, - } -} diff --git a/internal/throttle/round_tripper_test.go b/internal/throttle/round_tripper_test.go deleted file mode 100644 index 24e99ff27..000000000 --- a/internal/throttle/round_tripper_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package throttle - -import ( - "bytes" - "io" - "net/http" - "testing" - - "github.com/pkg/errors" -) - -type baseRoundTripper struct { - responses map[*http.Request]*http.Response -} - -func (rt *baseRoundTripper) add(req *http.Request, resp *http.Response) (*http.Request, *http.Response) { - rt.responses[req] = resp - return req, resp -} - -func (rt *baseRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - resp := rt.responses[req] - if resp != nil { - return resp, nil - } - - return nil, errors.Errorf("error occurred") -} - -type fakePool struct { - readers []io.ReadCloser -} - -func (fp *fakePool) reset() { - fp.readers = nil -} - -func (fp *fakePool) AddReader(r io.ReadCloser) (io.ReadCloser, error) { - fp.readers = append(fp.readers, r) - return r, nil -} - -//nolint:gocyclo -func TestRoundTripper(t *testing.T) { - downloadBody := io.NopCloser(bytes.NewReader([]byte("data1"))) - uploadBody := io.NopCloser(bytes.NewReader([]byte("data1"))) - - base := &baseRoundTripper{ - responses: make(map[*http.Request]*http.Response), - } - downloadPool := &fakePool{} - uploadPool := &fakePool{} - rt := NewRoundTripper(base, downloadPool, uploadPool) - - // Empty request (no request, no response) - uploadPool.reset() - downloadPool.reset() - - req1, resp1 := base.add(&http.Request{}, &http.Response{}) //nolint:bodyclose - resp, err := rt.RoundTrip(req1) //nolint:bodyclose - - if resp != resp1 || err != nil { - t.Errorf("invalid response or error: %v", err) - } - - 
if len(downloadPool.readers) != 0 || len(uploadPool.readers) != 0 { - t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers) - } - - // Upload request - uploadPool.reset() - downloadPool.reset() - - req2, resp2 := base.add(&http.Request{ //nolint:bodyclose - Body: uploadBody, - }, &http.Response{}) - resp, err = rt.RoundTrip(req2) //nolint:bodyclose - - if resp != resp2 || err != nil { - t.Errorf("invalid response or error: %v", err) - } - - if len(downloadPool.readers) != 0 || len(uploadPool.readers) != 1 { - t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers) - } - - // Download request - uploadPool.reset() - downloadPool.reset() - - req3, resp3 := base.add(&http.Request{}, &http.Response{Body: downloadBody}) //nolint:bodyclose - resp, err = rt.RoundTrip(req3) //nolint:bodyclose - - if resp != resp3 || err != nil { - t.Errorf("invalid response or error: %v", err) - } - - if len(downloadPool.readers) != 1 || len(uploadPool.readers) != 0 { - t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers) - } - - // Upload/Download request - uploadPool.reset() - downloadPool.reset() - - req4, resp4 := base.add(&http.Request{Body: uploadBody}, &http.Response{Body: downloadBody}) //nolint:bodyclose - - resp, err = rt.RoundTrip(req4) //nolint:bodyclose - if resp != resp4 || err != nil { - t.Errorf("invalid response or error: %v", err) - } - - if len(downloadPool.readers) != 1 || len(uploadPool.readers) != 1 { - t.Errorf("invalid pool contents: %v %v", downloadPool.readers, uploadPool.readers) - } -} diff --git a/repo/blob/azure/azure_options.go b/repo/blob/azure/azure_options.go index 0cfd8ea8a..bf652b9a2 100644 --- a/repo/blob/azure/azure_options.go +++ b/repo/blob/azure/azure_options.go @@ -1,5 +1,9 @@ package azure +import ( + "github.com/kopia/kopia/repo/blob/throttling" +) + // Options defines options for Azure blob storage storage. 
type Options struct { // Container is the name of the azure storage container where data is stored. @@ -17,6 +21,5 @@ type Options struct { StorageDomain string `json:"storageDomain,omitempty"` - MaxUploadSpeedBytesPerSecond int `json:"maxUploadSpeedBytesPerSecond,omitempty"` - MaxDownloadSpeedBytesPerSecond int `json:"maxDownloadSpeedBytesPerSecond,omitempty"` + throttling.Limits } diff --git a/repo/blob/azure/azure_storage.go b/repo/blob/azure/azure_storage.go index 4f9726df0..5e03551d2 100644 --- a/repo/blob/azure/azure_storage.go +++ b/repo/blob/azure/azure_storage.go @@ -10,7 +10,6 @@ "github.com/Azure/azure-pipeline-go/pipeline" "github.com/Azure/azure-storage-blob-go/azblob" - "github.com/efarrer/iothrottler" "github.com/pkg/errors" gblob "gocloud.dev/blob" "gocloud.dev/blob/azureblob" @@ -32,9 +31,6 @@ type azStorage struct { ctx context.Context bucket *gblob.Bucket - - downloadThrottler *iothrottler.IOThrottlerPool - uploadThrottler *iothrottler.IOThrottlerPool } func (az *azStorage) GetBlob(ctx context.Context, b blob.ID, offset, length int64, output blob.OutputBuffer) error { @@ -50,13 +46,8 @@ func (az *azStorage) GetBlob(ctx context.Context, b blob.ID, offset, length int6 defer reader.Close() //nolint:errcheck - throttled, err := az.downloadThrottler.AddReader(reader) - if err != nil { - return errors.Wrap(err, "AddReader") - } - // nolint:wrapcheck - return iocopy.JustCopy(output, throttled) + return iocopy.JustCopy(output, reader) } if err := attempt(); err != nil { @@ -112,12 +103,6 @@ func (az *azStorage) PutBlob(ctx context.Context, b blob.ID, data blob.Bytes, op ctx, cancel := context.WithCancel(ctx) defer cancel() - throttled, err := az.uploadThrottler.AddReader(io.NopCloser(data.Reader())) - if err != nil { - // nolint:wrapcheck - return err - } - // create azure Bucket writer writer, err := az.bucket.NewWriter(ctx, az.getObjectNameString(b), &gblob.WriterOptions{ContentType: "application/x-kopia"}) if err != nil { @@ -125,7 +110,7 @@ func (az 
*azStorage) PutBlob(ctx context.Context, b blob.ID, data blob.Bytes, op return err } - if err := iocopy.JustCopy(writer, throttled); err != nil { + if err := iocopy.JustCopy(writer, data.Reader()); err != nil { // cancel context before closing the writer causes it to abandon the upload. cancel() @@ -208,14 +193,6 @@ func (az *azStorage) FlushCaches(ctx context.Context) error { return nil } -func toBandwidth(bytesPerSecond int) iothrottler.Bandwidth { - if bytesPerSecond <= 0 { - return iothrottler.Unlimited - } - - return iothrottler.Bandwidth(bytesPerSecond) * iothrottler.BytesPerSecond -} - // New creates new Azure Blob Storage-backed storage with specified options: // // - the 'Container', 'StorageAccount' and 'StorageKey' fields are required and all other parameters are optional. @@ -258,15 +235,10 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) { return nil, errors.Wrap(err, "unable to open bucket") } - downloadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxDownloadSpeedBytesPerSecond)) - uploadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxUploadSpeedBytesPerSecond)) - az := retrying.NewWrapper(&azStorage{ - Options: *opt, - ctx: ctx, - bucket: bucket, - downloadThrottler: downloadThrottler, - uploadThrottler: uploadThrottler, + Options: *opt, + ctx: ctx, + bucket: bucket, }) // verify Azure connection is functional by listing blobs in a bucket, which will fail if the container diff --git a/repo/blob/b2/b2_options.go b/repo/blob/b2/b2_options.go index c0c805c85..505297735 100644 --- a/repo/blob/b2/b2_options.go +++ b/repo/blob/b2/b2_options.go @@ -1,5 +1,7 @@ package b2 +import "github.com/kopia/kopia/repo/blob/throttling" + // Options defines options for B2-based storage. type Options struct { // BucketName is the name of the bucket where data is stored. 
@@ -11,6 +13,5 @@ type Options struct { KeyID string `json:"keyID"` Key string `json:"key" kopia:"sensitive"` - MaxUploadSpeedBytesPerSecond int `json:"maxUploadSpeedBytesPerSecond,omitempty"` - MaxDownloadSpeedBytesPerSecond int `json:"maxDownloadSpeedBytesPerSecond,omitempty"` + throttling.Limits } diff --git a/repo/blob/b2/b2_storage.go b/repo/blob/b2/b2_storage.go index ffd514acb..d1b7ab9ed 100644 --- a/repo/blob/b2/b2_storage.go +++ b/repo/blob/b2/b2_storage.go @@ -4,12 +4,10 @@ import ( "context" "fmt" - "io" "net/http" "strings" "time" - "github.com/efarrer/iothrottler" "github.com/pkg/errors" "gopkg.in/kothar/go-backblaze.v0" @@ -29,9 +27,6 @@ type b2Storage struct { cli *backblaze.B2 bucket *backblaze.Bucket - - downloadThrottler *iothrottler.IOThrottlerPool - uploadThrottler *iothrottler.IOThrottlerPool } func (s *b2Storage) GetBlob(ctx context.Context, id blob.ID, offset, length int64, output blob.OutputBuffer) error { @@ -59,17 +54,12 @@ func (s *b2Storage) GetBlob(ctx context.Context, id blob.ID, offset, length int6 } defer r.Close() //nolint:errcheck - throttled, err := s.downloadThrottler.AddReader(r) - if err != nil { - return errors.Wrap(err, "DownloadFileRangeByName") - } - if length == 0 { return nil } // nolint:wrapcheck - return iocopy.JustCopy(output, throttled) + return iocopy.JustCopy(output, r) } if err := attempt(); err != nil { @@ -148,13 +138,8 @@ func translateError(err error) error { } func (s *b2Storage) PutBlob(ctx context.Context, id blob.ID, data blob.Bytes, opts blob.PutOptions) error { - throttled, err := s.uploadThrottler.AddReader(io.NopCloser(data.Reader())) - if err != nil { - return translateError(err) - } - fileName := s.getObjectNameString(id) - _, err = s.bucket.UploadFile(fileName, nil, throttled) + _, err := s.bucket.UploadFile(fileName, nil, data.Reader()) return translateError(err) } @@ -238,14 +223,6 @@ func (s *b2Storage) String() string { return fmt.Sprintf("b2://%s/%s", s.BucketName, s.Prefix) } -func 
toBandwidth(bytesPerSecond int) iothrottler.Bandwidth { - if bytesPerSecond <= 0 { - return iothrottler.Unlimited - } - - return iothrottler.Bandwidth(bytesPerSecond) * iothrottler.BytesPerSecond -} - // New creates new B2-backed storage with specified options. func New(ctx context.Context, opt *Options) (blob.Storage, error) { if opt.BucketName == "" { @@ -257,9 +234,6 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) { return nil, errors.Wrap(err, "unable to create client") } - downloadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxDownloadSpeedBytesPerSecond)) - uploadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxUploadSpeedBytesPerSecond)) - bucket, err := cli.Bucket(opt.BucketName) if err != nil { return nil, errors.Wrapf(err, "cannot open bucket %q", opt.BucketName) @@ -270,12 +244,10 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) { } return retrying.NewWrapper(&b2Storage{ - Options: *opt, - ctx: ctx, - cli: cli, - bucket: bucket, - downloadThrottler: downloadThrottler, - uploadThrottler: uploadThrottler, + Options: *opt, + ctx: ctx, + cli: cli, + bucket: bucket, }), nil } diff --git a/repo/blob/filesystem/filesystem_options.go b/repo/blob/filesystem/filesystem_options.go index 455c1bde4..4271e7af3 100644 --- a/repo/blob/filesystem/filesystem_options.go +++ b/repo/blob/filesystem/filesystem_options.go @@ -4,6 +4,7 @@ "os" "github.com/kopia/kopia/repo/blob/sharded" + "github.com/kopia/kopia/repo/blob/throttling" ) // Options defines options for Filesystem-backed storage. 
@@ -17,6 +18,7 @@ type Options struct { FileGID *int `json:"gid,omitempty"` sharded.Options + throttling.Limits } func (fso *Options) fileMode() os.FileMode { diff --git a/repo/blob/gcs/gcs_options.go b/repo/blob/gcs/gcs_options.go index 2a71d549f..78f1b9856 100644 --- a/repo/blob/gcs/gcs_options.go +++ b/repo/blob/gcs/gcs_options.go @@ -1,6 +1,10 @@ package gcs -import "encoding/json" +import ( + "encoding/json" + + "github.com/kopia/kopia/repo/blob/throttling" +) // Options defines options Google Cloud Storage-backed storage. type Options struct { @@ -19,7 +23,5 @@ type Options struct { // ReadOnly causes GCS connection to be opened with read-only scope to prevent accidental mutations. ReadOnly bool `json:"readOnly,omitempty"` - MaxUploadSpeedBytesPerSecond int `json:"maxUploadSpeedBytesPerSecond,omitempty"` - - MaxDownloadSpeedBytesPerSecond int `json:"maxDownloadSpeedBytesPerSecond,omitempty"` + throttling.Limits } diff --git a/repo/blob/gcs/gcs_storage.go b/repo/blob/gcs/gcs_storage.go index 93cdac9a3..1c6bda700 100644 --- a/repo/blob/gcs/gcs_storage.go +++ b/repo/blob/gcs/gcs_storage.go @@ -10,7 +10,6 @@ "time" gcsclient "cloud.google.com/go/storage" - "github.com/efarrer/iothrottler" "github.com/pkg/errors" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -20,7 +19,6 @@ "github.com/kopia/kopia/internal/clock" "github.com/kopia/kopia/internal/iocopy" - "github.com/kopia/kopia/internal/throttle" "github.com/kopia/kopia/repo/blob" "github.com/kopia/kopia/repo/blob/retrying" ) @@ -35,9 +33,6 @@ type gcsStorage struct { storageClient *gcsclient.Client bucket *gcsclient.BucketHandle - - downloadThrottler *iothrottler.IOThrottlerPool - uploadThrottler *iothrottler.IOThrottlerPool } func (gcs *gcsStorage) GetBlob(ctx context.Context, b blob.ID, offset, length int64, output blob.OutputBuffer) error { @@ -185,14 +180,6 @@ func (gcs *gcsStorage) FlushCaches(ctx context.Context) error { return nil } -func toBandwidth(bytesPerSecond int) iothrottler.Bandwidth { - if 
bytesPerSecond <= 0 { - return iothrottler.Unlimited - } - - return iothrottler.Bandwidth(bytesPerSecond) * iothrottler.BytesPerSecond -} - func tokenSourceFromCredentialsFile(ctx context.Context, fn string, scopes ...string) (oauth2.TokenSource, error) { data, err := os.ReadFile(fn) //nolint:gosec if err != nil { @@ -244,11 +231,7 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) { return nil, errors.Wrap(err, "unable to initialize token source") } - downloadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxDownloadSpeedBytesPerSecond)) - uploadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxUploadSpeedBytesPerSecond)) - hc := oauth2.NewClient(ctx, ts) - hc.Transport = throttle.NewRoundTripper(hc.Transport, downloadThrottler, uploadThrottler) cli, err := gcsclient.NewClient(ctx, option.WithHTTPClient(hc)) if err != nil { @@ -260,11 +243,9 @@ func New(ctx context.Context, opt *Options) (blob.Storage, error) { } gcs := &gcsStorage{ - Options: *opt, - storageClient: cli, - bucket: cli.Bucket(opt.BucketName), - downloadThrottler: downloadThrottler, - uploadThrottler: uploadThrottler, + Options: *opt, + storageClient: cli, + bucket: cli.Bucket(opt.BucketName), } // verify GCS connection is functional by listing blobs in a bucket, which will fail if the bucket diff --git a/repo/blob/rclone/rclone_options.go b/repo/blob/rclone/rclone_options.go index a44dc4c52..6c9b849b3 100644 --- a/repo/blob/rclone/rclone_options.go +++ b/repo/blob/rclone/rclone_options.go @@ -1,6 +1,9 @@ package rclone -import "github.com/kopia/kopia/repo/blob/sharded" +import ( + "github.com/kopia/kopia/repo/blob/sharded" + "github.com/kopia/kopia/repo/blob/throttling" +) // Options defines options for RClone storage. 
type Options struct { @@ -15,4 +18,5 @@ type Options struct { AtomicWrites bool `json:"atomicWrites"` sharded.Options + throttling.Limits } diff --git a/repo/blob/s3/s3_options.go b/repo/blob/s3/s3_options.go index c32e340bd..1b20ce490 100644 --- a/repo/blob/s3/s3_options.go +++ b/repo/blob/s3/s3_options.go @@ -1,6 +1,10 @@ package s3 -import "time" +import ( + "time" + + "github.com/kopia/kopia/repo/blob/throttling" +) // Options defines options for S3-based storage. type Options struct { @@ -21,9 +25,7 @@ type Options struct { // Region is an optional region to pass in authorization header. Region string `json:"region,omitempty"` - MaxUploadSpeedBytesPerSecond int `json:"maxUploadSpeedBytesPerSecond,omitempty"` - - MaxDownloadSpeedBytesPerSecond int `json:"maxDownloadSpeedBytesPerSecond,omitempty"` + throttling.Limits // PointInTime specifies a view of the (versioned) store at that time PointInTime *time.Time `json:"pointInTime,omitempty"` diff --git a/repo/blob/s3/s3_storage.go b/repo/blob/s3/s3_storage.go index 807d7ec93..5b4ed4ffb 100644 --- a/repo/blob/s3/s3_storage.go +++ b/repo/blob/s3/s3_storage.go @@ -12,7 +12,6 @@ "sync/atomic" "time" - "github.com/efarrer/iothrottler" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/pkg/errors" @@ -35,9 +34,7 @@ type s3Storage struct { cli *minio.Client - downloadThrottler *iothrottler.IOThrottlerPool - uploadThrottler *iothrottler.IOThrottlerPool - storageConfig *StorageConfig + storageConfig *StorageConfig } func (s *s3Storage) GetBlob(ctx context.Context, b blob.ID, offset, length int64, output blob.OutputBuffer) error { @@ -72,17 +69,12 @@ func (s *s3Storage) getBlobWithVersion(ctx context.Context, b blob.ID, version s defer o.Close() //nolint:errcheck - throttled, err := s.downloadThrottler.AddReader(o) - if err != nil { - return errors.Wrap(err, "AddReader") - } - if length == 0 { return nil } // nolint:wrapcheck - return iocopy.JustCopy(output, throttled) + return 
iocopy.JustCopy(output, o) } if err := attempt(); err != nil { @@ -138,11 +130,6 @@ func (s *s3Storage) PutBlob(ctx context.Context, b blob.ID, data blob.Bytes, opt } func (s *s3Storage) putBlob(ctx context.Context, b blob.ID, data blob.Bytes, opts blob.PutOptions) (versionMetadata, error) { - throttled, err := s.uploadThrottler.AddReader(io.NopCloser(data.Reader())) - if err != nil { - return versionMetadata{}, errors.Wrap(err, "AddReader") - } - var ( storageClass = s.storageConfig.getStorageClassForBlobID(b) retentionMode minio.RetentionMode @@ -158,7 +145,7 @@ func (s *s3Storage) putBlob(ctx context.Context, b blob.ID, data blob.Bytes, opt retainUntilDate = clock.Now().Add(opts.RetentionPeriod).UTC() } - uploadInfo, err := s.cli.PutObject(ctx, s.BucketName, s.getObjectNameString(b), throttled, int64(data.Length()), minio.PutObjectOptions{ + uploadInfo, err := s.cli.PutObject(ctx, s.BucketName, s.getObjectNameString(b), data.Reader(), int64(data.Length()), minio.PutObjectOptions{ ContentType: "application/x-kopia", SendContentMd5: atomic.LoadInt32(&s.sendMD5) > 0 || // The Content-MD5 header is required for any request to upload an object @@ -272,14 +259,6 @@ func (s *s3Storage) FlushCaches(ctx context.Context) error { return nil } -func toBandwidth(bytesPerSecond int) iothrottler.Bandwidth { - if bytesPerSecond <= 0 { - return iothrottler.Unlimited - } - - return iothrottler.Bandwidth(bytesPerSecond) * iothrottler.BytesPerSecond -} - func getCustomTransport(insecureSkipVerify bool) (transport *http.Transport) { // nolint:gosec customTransport := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: insecureSkipVerify}} @@ -323,9 +302,6 @@ func newStorage(ctx context.Context, opt *Options) (*s3Storage, error) { return nil, errors.Wrap(err, "unable to create client") } - downloadThrottler := iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxDownloadSpeedBytesPerSecond)) - uploadThrottler := 
iothrottler.NewIOThrottlerPool(toBandwidth(opt.MaxUploadSpeedBytesPerSecond)) - ok, err := cli.BucketExists(ctx, opt.BucketName) if err != nil { return nil, errors.Wrapf(err, "unable to determine if bucket %q exists", opt.BucketName) @@ -336,12 +312,10 @@ func newStorage(ctx context.Context, opt *Options) (*s3Storage, error) { } s := s3Storage{ - Options: *opt, - cli: cli, - sendMD5: 0, - downloadThrottler: downloadThrottler, - uploadThrottler: uploadThrottler, - storageConfig: &StorageConfig{}, + Options: *opt, + cli: cli, + sendMD5: 0, + storageConfig: &StorageConfig{}, } var scOutput gather.WriteBuffer diff --git a/repo/blob/sftp/sftp_options.go b/repo/blob/sftp/sftp_options.go index b84f25cfe..d9796d6f5 100644 --- a/repo/blob/sftp/sftp_options.go +++ b/repo/blob/sftp/sftp_options.go @@ -5,6 +5,7 @@ "path/filepath" "github.com/kopia/kopia/repo/blob/sharded" + "github.com/kopia/kopia/repo/blob/throttling" ) // Options defines options for sftp-backed storage. @@ -26,6 +27,7 @@ type Options struct { SSHArguments string `json:"sshArguments,omitempty"` sharded.Options + throttling.Limits } func (sftpo *Options) knownHostsFile() string { diff --git a/repo/blob/throttling/throttler.go b/repo/blob/throttling/throttler.go new file mode 100644 index 000000000..731248ba4 --- /dev/null +++ b/repo/blob/throttling/throttler.go @@ -0,0 +1,118 @@ +package throttling + +import ( + "context" + "sync" + "time" + + "github.com/pkg/errors" +) + +// SettableThrottler exposes methods to set throttling limits. 
+type SettableThrottler interface { + Throttler + + Limits() Limits + SetLimits(limits Limits) error +} + +type tokenBucketBasedThrottler struct { + mu sync.Mutex + limits Limits + + readOps *tokenBucket + writeOps *tokenBucket + listOps *tokenBucket + upload *tokenBucket + download *tokenBucket + window time.Duration +} + +func (t *tokenBucketBasedThrottler) BeforeOperation(ctx context.Context, op string) { + switch op { + case operationListBlobs: + t.listOps.Take(ctx, 1) + case operationGetBlob, operationGetMetadata: + t.readOps.Take(ctx, 1) + case operationPutBlob, operationDeleteBlob: + t.writeOps.Take(ctx, 1) + } +} + +func (t *tokenBucketBasedThrottler) BeforeDownload(ctx context.Context, numBytes int64) { + t.download.Take(ctx, float64(numBytes)) +} + +func (t *tokenBucketBasedThrottler) ReturnUnusedDownloadBytes(ctx context.Context, numBytes int64) { + t.download.Return(ctx, float64(numBytes)) +} + +func (t *tokenBucketBasedThrottler) BeforeUpload(ctx context.Context, numBytes int64) { + t.upload.Take(ctx, float64(numBytes)) +} + +func (t *tokenBucketBasedThrottler) Limits() Limits { + t.mu.Lock() + defer t.mu.Unlock() + + return t.limits +} + +// SetLimits overrides limits. 
+func (t *tokenBucketBasedThrottler) SetLimits(limits Limits) error { + t.mu.Lock() + defer t.mu.Unlock() + + t.limits = limits + + if err := t.readOps.SetLimit(limits.ReadsPerSecond * t.window.Seconds()); err != nil { + return errors.Wrap(err, "ReadsPerSecond") + } + + if err := t.writeOps.SetLimit(limits.WritesPerSecond * t.window.Seconds()); err != nil { + return errors.Wrap(err, "WritesPerSecond") + } + + if err := t.listOps.SetLimit(limits.ListsPerSecond * t.window.Seconds()); err != nil { + return errors.Wrap(err, "ListsPerSecond") + } + + if err := t.upload.SetLimit(limits.UploadBytesPerSecond * t.window.Seconds()); err != nil { + return errors.Wrap(err, "UploadBytesPerSecond") + } + + if err := t.download.SetLimit(limits.DownloadBytesPerSecond * t.window.Seconds()); err != nil { + return errors.Wrap(err, "DownloadBytesPerSecond") + } + + return nil +} + +// Limits encapsulates all limits for a Throttler. +type Limits struct { + ReadsPerSecond float64 `json:"readsPerSecond,omitempty"` + WritesPerSecond float64 `json:"writesPerSecond,omitempty"` + ListsPerSecond float64 `json:"listsPerSecond,omitempty"` + UploadBytesPerSecond float64 `json:"maxUploadSpeedBytesPerSecond,omitempty"` + DownloadBytesPerSecond float64 `json:"maxDownloadSpeedBytesPerSecond,omitempty"` +} + +var _ Throttler = (*tokenBucketBasedThrottler)(nil) + +// NewThrottler returns a Throttler with provided limits. 
+func NewThrottler(limits Limits, window time.Duration, initialFillRatio float64) (SettableThrottler, error) { + t := &tokenBucketBasedThrottler{ + readOps: newTokenBucket("read-ops", initialFillRatio*limits.ReadsPerSecond*window.Seconds(), 0, window), + writeOps: newTokenBucket("write-ops", initialFillRatio*limits.WritesPerSecond*window.Seconds(), 0, window), + listOps: newTokenBucket("list-ops", initialFillRatio*limits.ListsPerSecond*window.Seconds(), 0, window), + upload: newTokenBucket("upload-bytes", initialFillRatio*limits.UploadBytesPerSecond*window.Seconds(), 0, window), + download: newTokenBucket("download-bytes", initialFillRatio*limits.DownloadBytesPerSecond*window.Seconds(), 0, window), + window: window, + } + + if err := t.SetLimits(limits); err != nil { + return nil, errors.Wrap(err, "invalid limits") + } + + return t, nil +} diff --git a/repo/blob/throttling/throttler_test.go b/repo/blob/throttling/throttler_test.go new file mode 100644 index 000000000..1140d40c0 --- /dev/null +++ b/repo/blob/throttling/throttler_test.go @@ -0,0 +1,135 @@ +package throttling + +import ( + "context" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/kopia/kopia/internal/clock" + "github.com/kopia/kopia/internal/timetrack" +) + +func TestThrottler(t *testing.T) { + limits := Limits{ + ReadsPerSecond: 10, + WritesPerSecond: 10, + ListsPerSecond: 10, + UploadBytesPerSecond: 1000, + DownloadBytesPerSecond: 1000, + } + + const window = time.Second + + ctx := context.Background() + th, err := NewThrottler(limits, window, 0.0 /* start empty */) + require.NoError(t, err) + require.Equal(t, limits, th.Limits()) + + testRateLimiting(t, "DownloadBytesPerSecond", limits.DownloadBytesPerSecond, func(total *int64) { + numBytes := rand.Int63n(1500) + excess := rand.Int63n(10) + th.BeforeDownload(ctx, numBytes+excess) + th.ReturnUnusedDownloadBytes(ctx, excess) + atomic.AddInt64(total, numBytes) + }) + + th, err = 
NewThrottler(limits, window, 0.0 /* start empty */)
+	require.NoError(t, err)
+	testRateLimiting(t, "UploadBytesPerSecond", limits.UploadBytesPerSecond, func(total *int64) {
+		numBytes := rand.Int63n(1500)
+		th.BeforeUpload(ctx, numBytes)
+		atomic.AddInt64(total, numBytes)
+	})
+
+	th, err = NewThrottler(limits, window, 0.0 /* start empty */)
+	require.NoError(t, err)
+	testRateLimiting(t, "ReadsPerSecond", limits.ReadsPerSecond, func(total *int64) {
+		th.BeforeOperation(ctx, "GetBlob")
+		atomic.AddInt64(total, 1)
+	})
+
+	th, err = NewThrottler(limits, window, 0.0 /* start empty */)
+	require.NoError(t, err)
+	testRateLimiting(t, "WritesPerSecond", limits.WritesPerSecond, func(total *int64) {
+		th.BeforeOperation(ctx, "PutBlob")
+		atomic.AddInt64(total, 1)
+	})
+
+	th, err = NewThrottler(limits, window, 0.0 /* start empty */)
+	require.NoError(t, err)
+	testRateLimiting(t, "ListsPerSecond", limits.ListsPerSecond, func(total *int64) {
+		th.BeforeOperation(ctx, "ListBlobs")
+		atomic.AddInt64(total, 1)
+	})
+}
+
+func TestThrottlerLargeWindow(t *testing.T) {
+	limits := Limits{
+		ReadsPerSecond:         10,
+		WritesPerSecond:        10,
+		ListsPerSecond:         10,
+		UploadBytesPerSecond:   1000,
+		DownloadBytesPerSecond: 1000,
+	}
+
+	ctx := context.Background()
+	th, err := NewThrottler(limits, time.Minute, 1.0 /* start full */)
+	require.NoError(t, err)
+
+	// make sure we can consume 60x worth of the per-second quota without blocking.
+	timer := timetrack.StartTimer()
+
+	th.BeforeDownload(ctx, 60000)
+	require.Less(t, timer.Elapsed(), 500*time.Millisecond)
+
+	// subsequent call will block
+	timer = timetrack.StartTimer()
+
+	th.BeforeDownload(ctx, 1000)
+	require.Greater(t, timer.Elapsed(), 900*time.Millisecond)
+}
+
+// nolint:thelper
+func testRateLimiting(t *testing.T, name string, wantRate float64, worker func(total *int64)) {
+	t.Run(name, func(t *testing.T) {
+		t.Parallel()
+
+		const (
+			testDuration = 3 * time.Second
+			numWorkers   = 3
+		)
+
+		deadline := clock.Now().Add(testDuration)
+		total := 
new(int64) + + timer := timetrack.StartTimer() + + var wg sync.WaitGroup + + for i := 0; i < numWorkers; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for clock.Now().Before(deadline) { + worker(total) + } + }() + } + + wg.Wait() + + numSeconds := timer.Elapsed().Seconds() + actualRate := float64(*total) / numSeconds + + // make sure the rate is less than target with some tiny margin of error + require.Less(t, actualRate, wantRate*1.05) + require.Greater(t, actualRate, wantRate*0.9) + }) +} diff --git a/repo/blob/throttling/throttling_storage.go b/repo/blob/throttling/throttling_storage.go index 316e447cb..74addd2f3 100644 --- a/repo/blob/throttling/throttling_storage.go +++ b/repo/blob/throttling/throttling_storage.go @@ -13,6 +13,16 @@ // if we guess wrong or acquire more. const unknownBlobAcquireLength = 20000000 +// operations supported. +const ( + operationGetBlob = "GetBlob" + operationGetMetadata = "GetMetadata" + operationListBlobs = "ListBlobs" + operationSetTime = "SetTime" + operationPutBlob = "PutBlob" + operationDeleteBlob = "DeleteBlob" +) + // Throttler implements throttling policy by blocking before certain operations are // attempted to ensure we don't exceed the desired rate of operations/bytes uploaded/downloaded. 
type Throttler interface { @@ -42,7 +52,7 @@ func (s *throttlingStorage) GetBlob(ctx context.Context, id blob.ID, offset, len acquired = unknownBlobAcquireLength } - s.throttler.BeforeOperation(ctx, "GetBlob") + s.throttler.BeforeOperation(ctx, operationGetBlob) s.throttler.BeforeDownload(ctx, acquired) output.Reset() @@ -64,30 +74,30 @@ func (s *throttlingStorage) GetBlob(ctx context.Context, id blob.ID, offset, len } func (s *throttlingStorage) GetMetadata(ctx context.Context, id blob.ID) (blob.Metadata, error) { - s.throttler.BeforeOperation(ctx, "GetMetadata") + s.throttler.BeforeOperation(ctx, operationGetMetadata) return s.Storage.GetMetadata(ctx, id) // nolint:wrapcheck } func (s *throttlingStorage) ListBlobs(ctx context.Context, blobIDPrefix blob.ID, cb func(bm blob.Metadata) error) error { - s.throttler.BeforeOperation(ctx, "ListBlobs") + s.throttler.BeforeOperation(ctx, operationListBlobs) return s.Storage.ListBlobs(ctx, blobIDPrefix, cb) // nolint:wrapcheck } func (s *throttlingStorage) SetTime(ctx context.Context, id blob.ID, t time.Time) error { - s.throttler.BeforeOperation(ctx, "SetTime") + s.throttler.BeforeOperation(ctx, operationSetTime) return s.Storage.SetTime(ctx, id, t) // nolint:wrapcheck } func (s *throttlingStorage) PutBlob(ctx context.Context, id blob.ID, data blob.Bytes, opts blob.PutOptions) error { - s.throttler.BeforeOperation(ctx, "PutBlob") + s.throttler.BeforeOperation(ctx, operationPutBlob) s.throttler.BeforeUpload(ctx, int64(data.Length())) return s.Storage.PutBlob(ctx, id, data, opts) // nolint:wrapcheck } func (s *throttlingStorage) DeleteBlob(ctx context.Context, id blob.ID) error { - s.throttler.BeforeOperation(ctx, "DeleteBlob") + s.throttler.BeforeOperation(ctx, operationDeleteBlob) return s.Storage.DeleteBlob(ctx, id) // nolint:wrapcheck } diff --git a/repo/blob/throttling/token_bucket.go b/repo/blob/throttling/token_bucket.go new file mode 100644 index 000000000..26bee670a --- /dev/null +++ 
b/repo/blob/throttling/token_bucket.go @@ -0,0 +1,122 @@ +package throttling + +import ( + "context" + "sync" + "time" + + "github.com/pkg/errors" + + "github.com/kopia/kopia/repo/logging" +) + +var log = logging.Module("throttling") + +type tokenBucket struct { + name string + now func() time.Time + sleep func(ctx context.Context, d time.Duration) + + mu sync.Mutex + lastTime time.Time + numTokens float64 + maxTokens float64 + addTokensTimeUnit time.Duration +} + +func (b *tokenBucket) replenishTokens(now time.Time) { + if !b.lastTime.IsZero() { + // add tokens based on passage of time, ensuring we don't exceed maxTokens + elapsed := now.Sub(b.lastTime) + addTokens := b.maxTokens * elapsed.Seconds() / b.addTokensTimeUnit.Seconds() + + b.numTokens += addTokens + if b.numTokens > b.maxTokens { + b.numTokens = b.maxTokens + } + } + + b.lastTime = now +} + +func (b *tokenBucket) sleepDurationBeforeTokenAreAvailable(n float64, now time.Time) time.Duration { + b.mu.Lock() + defer b.mu.Unlock() + + if b.maxTokens == 0 { + return 0 + } + + b.replenishTokens(now) + + // consume N tokens. 
+	b.numTokens -= n
+
+	if b.numTokens >= 0 {
+		// tokens are immediately available
+		return 0
+	}
+
+	return time.Duration(float64(b.addTokensTimeUnit.Nanoseconds()) * (-b.numTokens / b.maxTokens))
+}
+
+func (b *tokenBucket) Take(ctx context.Context, n float64) {
+	d := b.TakeDuration(ctx, n)
+	if d > 0 {
+		log(ctx).Debugf("sleeping for %v to refill token bucket %v", d, b.name)
+		b.sleep(ctx, d)
+	}
+}
+
+func (b *tokenBucket) TakeDuration(ctx context.Context, n float64) time.Duration {
+	return b.sleepDurationBeforeTokenAreAvailable(n, b.now())
+}
+
+func (b *tokenBucket) Return(ctx context.Context, n float64) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	b.numTokens += n
+	if b.numTokens > b.maxTokens {
+		b.numTokens = b.maxTokens
+	}
+}
+
+func (b *tokenBucket) SetLimit(maxTokens float64) error {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	if maxTokens < 0 {
+		return errors.Errorf("limit cannot be negative")
+	}
+
+	b.maxTokens = maxTokens
+
+	// clamp stored tokens so they never exceed the new limit.
+	if b.numTokens > b.maxTokens {
+		b.numTokens = b.maxTokens
+	}
+
+	return nil
+}
+
+func sleepWithContext(ctx context.Context, dur time.Duration) {
+	t := time.NewTimer(dur)
+	defer t.Stop()
+
+	select {
+	case <-ctx.Done():
+	case <-t.C:
+	}
+}
+
+func newTokenBucket(name string, initialTokens, maxTokens float64, addTimeUnit time.Duration) *tokenBucket {
+	return &tokenBucket{
+		name:              name,
+		now:               time.Now, // nolint:forbidigo
+		sleep:             sleepWithContext,
+		numTokens:         initialTokens,
+		maxTokens:         maxTokens,
+		addTokensTimeUnit: addTimeUnit,
+	}
+}
diff --git a/repo/blob/throttling/token_bucket_test.go b/repo/blob/throttling/token_bucket_test.go
new file mode 100644
index 000000000..088673348
--- /dev/null
+++ b/repo/blob/throttling/token_bucket_test.go
@@ -0,0 +1,76 @@
+package throttling
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestTokenBucket(t *testing.T) {
+	b := newTokenBucket("test-bucket", 1000, 1000, time.Second)
+	ctx := context.Background()
+
+	currentTime := time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC)
+
+	verifyTakeTimeElapsed := func(take float64, wantSleep time.Duration) {
+		t0 := currentTime
+
+		b.Take(ctx, take)
+
+		diff := currentTime.Sub(t0)
+
+		require.Equal(t, wantSleep, diff)
+	}
+
+	advanceTime := func(dur time.Duration) {
+		currentTime = currentTime.Add(dur)
+	}
+
+	b.now = func() time.Time {
+		return currentTime
+	}
+	b.sleep = func(ctx context.Context, d time.Duration) {
+		currentTime = currentTime.Add(d)
+	}
+
+	verifyTakeTimeElapsed(0, 0)
+	require.Equal(t, 1000.0, b.numTokens)
+
+	// we did not sleep and grabbed all tokens.
+	verifyTakeTimeElapsed(1000, 0)
+	require.Equal(t, 0.0, b.numTokens)
+
+	// token bucket is empty, consuming 500 will require waiting 0.5 seconds
+	verifyTakeTimeElapsed(500, 500*time.Millisecond)
+	require.Equal(t, -500.0, b.numTokens)
+
+	// grabbing zero will reset tokens to zero based on passage of time.
+	verifyTakeTimeElapsed(0, 0)
+	require.Equal(t, 0.0, b.numTokens)
+
+	advanceTime(1 * time.Second)
+	verifyTakeTimeElapsed(0, 0)
+	require.Equal(t, 1000.0, b.numTokens)
+
+	// the bucket is already full at this point; waiting even longer must not overfill it.
+ advanceTime(5 * time.Second) + verifyTakeTimeElapsed(0, 0) + + require.Equal(t, 1000.0, b.numTokens) + + // now we can grab all tokens without sleeping + verifyTakeTimeElapsed(300, 0) + verifyTakeTimeElapsed(700, 0) + verifyTakeTimeElapsed(1000, time.Second) + verifyTakeTimeElapsed(100, 100*time.Millisecond) + + advanceTime(5 * time.Second) + + verifyTakeTimeElapsed(1000, 0) + b.Return(ctx, 2000) + verifyTakeTimeElapsed(1000, 0) + b.Return(ctx, 1000) + verifyTakeTimeElapsed(1000, 0) +} diff --git a/repo/blob/webdav/webdav_options.go b/repo/blob/webdav/webdav_options.go index 29b1ed10a..c2635e006 100644 --- a/repo/blob/webdav/webdav_options.go +++ b/repo/blob/webdav/webdav_options.go @@ -1,6 +1,9 @@ package webdav -import "github.com/kopia/kopia/repo/blob/sharded" +import ( + "github.com/kopia/kopia/repo/blob/sharded" + "github.com/kopia/kopia/repo/blob/throttling" +) // Options defines options for Filesystem-backed storage. type Options struct { @@ -11,4 +14,5 @@ type Options struct { AtomicWrites bool `json:"atomicWrites"` sharded.Options + throttling.Limits } diff --git a/repo/open.go b/repo/open.go index 471b172a2..855ccc377 100644 --- a/repo/open.go +++ b/repo/open.go @@ -2,6 +2,7 @@ import ( "context" + "encoding/json" "os" "path/filepath" "time" @@ -16,6 +17,7 @@ "github.com/kopia/kopia/repo/blob" loggingwrapper "github.com/kopia/kopia/repo/blob/logging" "github.com/kopia/kopia/repo/blob/readonly" + "github.com/kopia/kopia/repo/blob/throttling" "github.com/kopia/kopia/repo/content" "github.com/kopia/kopia/repo/logging" "github.com/kopia/kopia/repo/manifest" @@ -33,6 +35,13 @@ // as valid. const defaultFormatBlobCacheDuration = 15 * time.Minute +// throttlingWindow is the duration window during which the throttling token bucket fully replenishes. +// the maximum number of tokens in the bucket is multiplied by the number of seconds. +const throttlingWindow = 60 * time.Second + +// start with 10% of tokens in the bucket. 
+const throttleBucketInitialFill = 0.1 + // localCacheIntegrityHMACSecretLength length of HMAC secret protecting local cache items. const localCacheIntegrityHMACSecretLength = 16 @@ -164,6 +173,7 @@ func openDirect(ctx context.Context, configFile string, lc *LocalConfig, passwor } // openWithConfig opens the repository with a given configuration, avoiding the need for a config file. +// nolint:funlen func openWithConfig(ctx context.Context, st blob.Storage, lc *LocalConfig, password string, options *Options, caching *content.CachingOptions, configFile string) (DirectRepository, error) { caching = caching.CloneOrDefault() @@ -222,6 +232,11 @@ func openWithConfig(ctx context.Context, st blob.Storage, lc *LocalConfig, passw cmOpts.RepositoryFormatBytes = nil } + st, throttler, err := addThrottler(ctx, st) + if err != nil { + return nil, errors.Wrap(err, "unable to add throttler") + } + scm, err := content.NewSharedManager(ctx, st, fo, caching, cmOpts) if err != nil { return nil, errors.Wrap(err, "unable to create shared content manager") @@ -243,11 +258,12 @@ func openWithConfig(ctx context.Context, st blob.Storage, lc *LocalConfig, passw } dr := &directRepository{ - cmgr: cm, - omgr: om, - blobs: st, - mmgr: manifests, - sm: scm, + cmgr: cm, + omgr: om, + blobs: st, + mmgr: manifests, + sm: scm, + throttler: throttler, directRepositoryParameters: directRepositoryParameters{ uniqueID: f.UniqueID, cachingOptions: *caching, @@ -264,6 +280,33 @@ func openWithConfig(ctx context.Context, st blob.Storage, lc *LocalConfig, passw return dr, nil } +func addThrottler(ctx context.Context, st blob.Storage) (blob.Storage, throttling.SettableThrottler, error) { + throttler, err := throttling.NewThrottler( + throttlingLimitsFromConnectionInfo(ctx, st.ConnectionInfo()), throttlingWindow, throttleBucketInitialFill) + if err != nil { + return nil, nil, errors.Wrap(err, "unable to create throttler") + } + + return throttling.NewWrapper(st, throttler), throttler, nil +} + +func 
throttlingLimitsFromConnectionInfo(ctx context.Context, ci blob.ConnectionInfo) throttling.Limits { + v, err := json.Marshal(ci.Config) + if err != nil { + return throttling.Limits{} + } + + var l throttling.Limits + + if err := json.Unmarshal(v, &l); err != nil { + return throttling.Limits{} + } + + log(ctx).Debugw("throttling limits from connection info", "limits", l) + + return l +} + func writeCacheMarker(cacheDir string) error { if cacheDir == "" { return nil diff --git a/repo/repository.go b/repo/repository.go index 480ce5ebf..e7da35530 100644 --- a/repo/repository.go +++ b/repo/repository.go @@ -10,6 +10,7 @@ "github.com/kopia/kopia/internal/clock" "github.com/kopia/kopia/repo/blob" + "github.com/kopia/kopia/repo/blob/throttling" "github.com/kopia/kopia/repo/content" "github.com/kopia/kopia/repo/manifest" "github.com/kopia/kopia/repo/object" @@ -64,6 +65,8 @@ type DirectRepository interface { DeriveKey(purpose []byte, keyLength int) []byte Token(password string) (string, error) + Throttler() throttling.SettableThrottler + DisableIndexRefresh() } @@ -99,6 +102,8 @@ type directRepository struct { mmgr *manifest.Manager sm *content.SharedManager + throttler throttling.SettableThrottler + closed chan struct{} } @@ -124,6 +129,11 @@ func (r *directRepository) BlobStorage() blob.Storage { return r.blobs } +// Throttler returns the blob storage throttler. +func (r *directRepository) Throttler() throttling.SettableThrottler { + return r.throttler +} + // ContentManager returns the content manager. 
func (r *directRepository) ContentManager() *content.WriteManager { return r.cmgr diff --git a/tests/end_to_end_test/server_start_test.go b/tests/end_to_end_test/server_start_test.go index 85bd99183..cd43c72a1 100644 --- a/tests/end_to_end_test/server_start_test.go +++ b/tests/end_to_end_test/server_start_test.go @@ -63,7 +63,7 @@ func TestServerStart(t *testing.T) { defer e.RunAndExpectSuccess(t, "repo", "disconnect") - e.RunAndExpectSuccess(t, "repo", "create", "filesystem", "--path", e.RepoDir, "--override-hostname=fake-hostname", "--override-username=fake-username") + e.RunAndExpectSuccess(t, "repo", "create", "filesystem", "--path", e.RepoDir, "--override-hostname=fake-hostname", "--override-username=fake-username", "--max-upload-speed=10000000001") e.RunAndExpectSuccess(t, "snapshot", "create", sharedTestDataDir1) e.RunAndExpectSuccess(t, "snapshot", "create", sharedTestDataDir1) @@ -100,6 +100,20 @@ func TestServerStart(t *testing.T) { st := verifyServerConnected(t, cli, true) require.Equal(t, "filesystem", st.Storage) + limits, err := serverapi.GetThrottlingLimits(ctx, cli) + require.NoError(t, err) + + // make sure limits are preserved + require.Equal(t, 10000000001.0, limits.UploadBytesPerSecond) + + // change the limit via the API. + limits.UploadBytesPerSecond++ + require.NoError(t, serverapi.SetThrottlingLimits(ctx, cli, limits)) + + limits, err = serverapi.GetThrottlingLimits(ctx, cli) + require.NoError(t, err) + require.Equal(t, 10000000002.0, limits.UploadBytesPerSecond) + sources := verifySourceCount(t, cli, nil, 1) require.Equal(t, sharedTestDataDir1, sources[0].Source.Path)