From e80f5536c318e7c21161a09a08f6e185d36bd045 Mon Sep 17 00:00:00 2001 From: Jarek Kowalski Date: Thu, 12 Mar 2020 08:27:44 -0700 Subject: [PATCH] =?UTF-8?q?performance:=20plumbed=20through=20output=20buf?= =?UTF-8?q?fer=20to=20encryption=20and=20hashing,=E2=80=A6=20(#333)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * performance: plumbed through output buffer to encryption and hashing, so that the caller can pre-allocate/reuse it * testing: fixed how we do comparison of byte slices to account for possible nils, which can be returned from encryption --- Makefile | 1 + cli/command_benchmark_crypto.go | 6 +- internal/bufcache/bufcache.go | 78 +++++++++++++++++++ internal/bufcache/bufcache_test.go | 31 ++++++++ internal/hmac/hmac.go | 4 +- repo/content/content_cache_test.go | 2 +- repo/content/content_formatter_test.go | 11 ++- repo/content/content_index_recovery.go | 4 +- repo/content/content_manager.go | 49 +++++++++++- repo/content/content_manager_lock_free.go | 47 +++++------ repo/content/content_manager_test.go | 4 +- repo/content/packindex_test.go | 4 +- repo/encryption/aead_helpers.go | 16 ++-- .../aes256_gcm_hmac_sha256_encryptor.go | 28 +++++-- ...chacha20_poly1305_hmac_sha256_encryptor.go | 29 +++++-- repo/encryption/deprecated_ctr_encryptor.go | 15 ++-- repo/encryption/deprecated_salsa_encryptor.go | 14 ++-- repo/encryption/encryption.go | 32 ++++++-- repo/encryption/encryption_test.go | 29 ++++--- repo/encryption/null_encryptor.go | 8 +- repo/hashing/hashing.go | 37 +++++++-- repo/hashing/hashing_test.go | 11 ++- tests/stress_test/stress_test.go | 4 +- 23 files changed, 354 insertions(+), 110 deletions(-) create mode 100644 internal/bufcache/bufcache.go create mode 100644 internal/bufcache/bufcache_test.go diff --git a/Makefile b/Makefile index 8d68fb66b..9ef450be5 100644 --- a/Makefile +++ b/Makefile @@ -189,6 +189,7 @@ ifneq ($(uname),Windows) -e github.com/kopia/kopia/internal/blobtesting \ -e github.com/kopia/kopia/internal/repotesting \ -e github.com/kopia/kopia/internal/testlogging \ + -e github.com/kopia/kopia/internal/bufcache \ -e github.com/kopia/kopia/internal/hmac \ -e github.com/kopia/kopia/internal/faketime \ -e github.com/kopia/kopia/internal/testutil \ diff --git a/cli/command_benchmark_crypto.go b/cli/command_benchmark_crypto.go index acde14a02..2958024f0 100644 --- a/cli/command_benchmark_crypto.go +++ b/cli/command_benchmark_crypto.go @@ -53,9 +53,11 @@ type benchResult struct { t0 := time.Now() hashCount := *benchmarkCryptoRepeat + hashOutput := make([]byte, 0, 64) + for i := 0; i < hashCount; i++ { - contentID := h(data) - if _, encerr := e.Encrypt(data, contentID); encerr != nil { + contentID := h(hashOutput[:0], data) + if _, encerr := e.Encrypt(nil, data, contentID); encerr != nil { printStderr("encryption failed: %v\n", encerr) break } diff --git a/internal/bufcache/bufcache.go b/internal/bufcache/bufcache.go new file mode 100644 index 000000000..cd057fbcb --- /dev/null +++ b/internal/bufcache/bufcache.go @@ -0,0 +1,78 @@ +// Package bufcache allocates and recycles byte slices used as buffers. +package bufcache + +import ( + "sync" +) + +type poolWithCapacity struct { + capacity int + pool *sync.Pool +} + +// pools keep track of sync.Pools holding pointers to byte slices of exactly the provided capacity. +// when allocating, we pick from the smallest pool that fits. +var pools = []poolWithCapacity{ + {1 << 8, &sync.Pool{}}, // 256 B + {1 << 10, &sync.Pool{}}, // 1 KB + {1 << 12, &sync.Pool{}}, // 4 KB + {1 << 14, &sync.Pool{}}, // 16 KB + {1 << 16, &sync.Pool{}}, // 64 KB + {1 << 18, &sync.Pool{}}, // 256 KB + {1 << 20, &sync.Pool{}}, // 1 MB + {1 << 21, &sync.Pool{}}, // 2 MB + {1 << 22, &sync.Pool{}}, // 4 MB + {1 << 23, &sync.Pool{}}, // 8 MB + {1 << 24, &sync.Pool{}}, // 16 MB + {1 << 25, &sync.Pool{}}, // 32 MB +} + +// EmptyBytesWithCapacity returns slice of length 0 with >= given capacity. +func EmptyBytesWithCapacity(capacity int) []byte { + if p, ok := findPoolWithSize(capacity); ok { + return getOrAllocate(p.pool, p.capacity) + } + + // beyond largest bucket, allocate + return make([]byte, 0, capacity) +} + +// Clone clones given slice onto a slice from the cache. +func Clone(b []byte) []byte { + return append(EmptyBytesWithCapacity(len(b)), b...) +} + +// Return returns the given slice back to the pool. +func Return(b []byte) { + if p, ok := findPoolWithSize(cap(b)); ok && p.capacity == cap(b) { + p.pool.Put(&b) + } +} + +func findPoolWithSize(capacity int) (poolWithCapacity, bool) { + // quick binary search to find the right pool bucket + l, r := 0, len(pools) + + for l < r { + if m := (l + r) >> 1; pools[m].capacity < capacity { + l = m + 1 + } else { + r = m + } + } + + if l < len(pools) { + return pools[l], true + } + + return poolWithCapacity{}, false +} + +func getOrAllocate(p *sync.Pool, capacity int) []byte { + v := p.Get() + if v == nil { + return make([]byte, 0, capacity) + } + + return (*v.(*[]byte))[:0] +} diff --git a/internal/bufcache/bufcache_test.go b/internal/bufcache/bufcache_test.go new file mode 100644 index 000000000..7b42c5365 --- /dev/null +++ b/internal/bufcache/bufcache_test.go @@ -0,0 +1,31 @@ +package bufcache_test + +import ( + "testing" + + "github.com/kopia/kopia/internal/bufcache" +) + +func TestBufCache(t *testing.T) { + cases := []struct { + requestCap int + wantResultCap int + }{ + {0, 256}, + {1, 256}, + {256, 256}, + {257, 1024}, + {1024, 1024}, + {1025, 4096}, + {1 << 24, 1 << 24}, // 16 MB + {1 << 25, 1 << 25}, // 32 MB + {1<<25 + 3, 1<<25 + 3}, // 32 MB + 3, not pooled anymore + } + + for _, tc := range cases { + result := bufcache.EmptyBytesWithCapacity(tc.requestCap) + if got, want := cap(result), tc.wantResultCap; got != want { + t.Errorf("got invalid capacity of buffer: %v, want %v", got, want) + } + } +} diff --git a/internal/hmac/hmac.go b/internal/hmac/hmac.go index eda793d86..268a010af 100644 --- a/internal/hmac/hmac.go +++ b/internal/hmac/hmac.go @@ -27,7 +27,9 @@ func VerifyAndStrip(b, secret []byte) ([]byte, error) { h := hmac.New(sha256.New, secret) h.Write(data) // nolint:errcheck - validSignature := h.Sum(nil) + + var sigBuf [32]byte + validSignature := h.Sum(sigBuf[:0]) if len(signature) != len(validSignature) { return nil, errors.New("invalid signature length") diff --git a/repo/content/content_cache_test.go b/repo/content/content_cache_test.go index 8715b68bd..d0543e9ec 100644 --- a/repo/content/content_cache_test.go +++ b/repo/content/content_cache_test.go @@ -148,7 +148,7 @@ func verifyContentCache(t *testing.T, cache *contentCache) { } else if err != nil && err.Error() != tc.err.Error() { t.Errorf("unexpected error for %v: %+v, wanted %+v", tc.cacheKey, err, tc.err) } - if !reflect.DeepEqual(v, tc.expected) { + if !bytes.Equal(v, tc.expected) { t.Errorf("unexpected data for %v: %x, wanted %x", tc.cacheKey, v, tc.expected) } } diff --git a/repo/content/content_formatter_test.go b/repo/content/content_formatter_test.go index f4811b113..a5015b96d 100644 --- a/repo/content/content_formatter_test.go +++ b/repo/content/content_formatter_test.go @@ -5,7 +5,6 @@ "context" cryptorand "crypto/rand" "crypto/sha1" - "reflect" "strings" "testing" "time" @@ -68,14 +67,14 @@ func TestFormatters(t *testing.T) { return } - contentID := h(data) + contentID := h(nil, data) - cipherText, err := e.Encrypt(data, contentID) + cipherText, err := e.Encrypt(nil, data, contentID) if err != nil || cipherText == nil { t.Errorf("invalid response from Encrypt: %v %v", cipherText, err) } - plainText, err := e.Decrypt(cipherText, contentID) + plainText, err := e.Decrypt(nil, cipherText, contentID) if err != nil || plainText == nil { t.Errorf("invalid response from Decrypt: %v %v", plainText, err) } @@ -132,8 +131,8 @@ func verifyEndToEndFormatter(ctx context.Context, t *testing.T, hashAlgo, encryp return } - if got, want := b2, b; !reflect.DeepEqual(got, want) { - t.Errorf("content %q data mismatch: got %x (nil:%v), wanted %x (nil:%v)", contentID, got, got == nil, want, want == nil) + if got, want := b2, b; !bytes.Equal(got, want) { + t.Errorf("content %q data mismatch: got %x, wanted %x", contentID, got, want) return } } diff --git a/repo/content/content_index_recovery.go b/repo/content/content_index_recovery.go index a2669cf7d..740a7ac5e 100644 --- a/repo/content/content_index_recovery.go +++ b/repo/content/content_index_recovery.go @@ -176,9 +176,9 @@ func (bm *lockFreeManager) appendPackFileIndexRecoveryData(ctx context.Context, return nil, err } - localIndexIV := bm.hashData(localIndex) + localIndexIV := bm.hashData(nil, localIndex) - encryptedLocalIndex, err := bm.encryptor.Encrypt(localIndex, localIndexIV) + encryptedLocalIndex, err := bm.encryptor.Encrypt(nil, localIndex, localIndexIV) if err != nil { return nil, err } diff --git a/repo/content/content_manager.go b/repo/content/content_manager.go index da5bdf5a8..372c05f60 100644 --- a/repo/content/content_manager.go +++ b/repo/content/content_manager.go @@ -14,6 +14,7 @@ "github.com/pkg/errors" "go.opencensus.io/stats" + "github.com/kopia/kopia/internal/bufcache" "github.com/kopia/kopia/repo/blob" "github.com/kopia/kopia/repo/logging" ) @@ -27,6 +28,8 @@ const ( PackBlobIDPrefixRegular blob.ID = "p" PackBlobIDPrefixSpecial blob.ID = "q" + + maxHashSize = 64 ) // PackBlobIDPrefixes contains all possible prefixes for pack blobs. @@ -78,6 +81,7 @@ type Manager struct { disableIndexFlushCount int flushPackIndexesAfter time.Time // time when those indexes should be flushed closed chan struct{} + bufferPool sync.Pool lockFreeManager } @@ -148,7 +152,7 @@ func (bm *Manager) deletePreexistingContent(ci Info) { func (bm *Manager) addToPackUnlocked(ctx context.Context, contentID ID, data []byte, isDeleted bool) error { prefix := packPrefixForContentID(contentID) - data = cloneBytes(data) + data = bufcache.Clone(data) bm.lock() @@ -336,6 +340,11 @@ func (bm *Manager) writePackAndAddToIndex(ctx context.Context, pp *pendingPackIn bm.packIndexBuilder.Add(*info) } + // return all cloned memory back to the buffer so that it can be used again + for _, it := range pp.currentPackItems { + bufcache.Return(it.Payload) + } + return nil } @@ -353,7 +362,9 @@ func (bm *Manager) prepareAndWritePackInternal(ctx context.Context, pp *pendingP packFile := blob.ID(fmt.Sprintf("%v%x", pp.prefix, contentID)) - contentData, packFileIndex, err := bm.preparePackDataContent(ctx, pp, packFile) + estimated := bm.estimatePackBlobSize(pp) + + contentData, packFileIndex, err := bm.preparePackDataContent(ctx, bufcache.EmptyBytesWithCapacity(estimated), pp, packFile) if err != nil { return nil, errors.Wrap(err, "error preparing data content") } @@ -364,11 +375,36 @@ func (bm *Manager) prepareAndWritePackInternal(ctx context.Context, pp *pendingP } formatLog(ctx).Debugf("wrote pack file: %v (%v bytes)", packFile, len(contentData)) + bufcache.Return(contentData) + } + + if estimated < len(contentData) { + log(ctx).Warningf("did not estimate content length: %v, predicted %v", len(contentData), estimated) } return packFileIndex, nil } +// estimatePackBlobSize estimates the size of the buffer to hold the pack blob. +// we use this to preallocate buffer and avoid wasteful reallocations. +// this function can overshoot, but best not to overshoot by too much. +func (bm *Manager) estimatePackBlobSize(pp *pendingPackInfo) int { + const ( + estimatedPackIndexOverhead = 10000 + estimatedPerItemOverhead = 64 + ) + + estimateCapacity := 0 + for _, pp := range pp.currentPackItems { + estimateCapacity += int(pp.Length) + estimatedPerItemOverhead + } + + estimateCapacity += len(bm.repositoryFormatBytes) + estimateCapacity += estimatedPackIndexOverhead + + return estimateCapacity +} + func removePendingPack(slice []*pendingPackInfo, pp *pendingPackInfo) []*pendingPackInfo { result := slice[:0] @@ -470,7 +506,9 @@ func (bm *Manager) WriteContent(ctx context.Context, data []byte, prefix ID) (ID return "", err } - contentID := prefix + ID(hex.EncodeToString(bm.hashData(data))) + var hashOutput [maxHashSize]byte + + contentID := prefix + ID(hex.EncodeToString(bm.hashData(hashOutput[:0], data))) // content already tracked if bi, err := bm.getContentInfo(contentID); err == nil { @@ -664,6 +702,11 @@ func newManagerWithOptions(ctx context.Context, st blob.Storage, f *FormattingOp pendingPacks: map[blob.ID]*pendingPackInfo{}, packIndexBuilder: make(packIndexBuilder), closed: make(chan struct{}), + bufferPool: sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, + }, } if err := m.CompactIndexes(ctx, autoCompactionOptions); err != nil { diff --git a/repo/content/content_manager_lock_free.go b/repo/content/content_manager_lock_free.go index ac85118c3..915a7ce4e 100644 --- a/repo/content/content_manager_lock_free.go +++ b/repo/content/content_manager_lock_free.go @@ -48,13 +48,13 @@ type lockFreeManager struct { repositoryFormatBytes []byte } -func (bm *lockFreeManager) maybeEncryptContentDataForPacking(data []byte, contentID ID) ([]byte, error) { +func (bm *lockFreeManager) maybeEncryptContentDataForPacking(output, data []byte, contentID ID) ([]byte, error) { iv, err := getPackedContentIV(contentID) if err != nil { return nil, errors.Wrapf(err, "unable to get packed content IV for %q", contentID) } - return bm.encryptor.Encrypt(data, iv) + return bm.encryptor.Encrypt(output, data, iv) } func appendRandomBytes(b []byte, count int) ([]byte, error) { @@ -235,7 +235,7 @@ func (bm *lockFreeManager) getContentDataUnlocked(ctx context.Context, bi *Info) } func (bm *lockFreeManager) decryptAndVerify(encrypted, iv []byte) ([]byte, error) { - decrypted, err := bm.encryptor.Decrypt(encrypted, iv) + decrypted, err := bm.encryptor.Decrypt(nil, encrypted, iv) if err != nil { return nil, errors.Wrap(err, "decrypt") } @@ -252,10 +252,10 @@ func (bm *lockFreeManager) decryptAndVerify(encrypted, iv []byte) ([]byte, error return decrypted, bm.verifyChecksum(decrypted, iv) } -func (bm *lockFreeManager) preparePackDataContent(ctx context.Context, pp *pendingPackInfo, packFile blob.ID) ([]byte, packIndexBuilder, error) { +func (bm *lockFreeManager) preparePackDataContent(ctx context.Context, contentData []byte, pp *pendingPackInfo, packFile blob.ID) ([]byte, packIndexBuilder, error) { formatLog(ctx).Debugf("preparing content data with %v items", len(pp.currentPackItems)) - contentData, err := appendRandomBytes(append([]byte(nil), bm.repositoryFormatBytes...), rand.Intn(bm.maxPreambleLength-bm.minPreambleLength+1)+bm.minPreambleLength) + contentData, err := appendRandomBytes(append(contentData, bm.repositoryFormatBytes...), rand.Intn(bm.maxPreambleLength-bm.minPreambleLength+1)+bm.minPreambleLength) if err != nil { return nil, nil, errors.Wrap(err, "unable to prepare content preamble") } @@ -263,6 +263,8 @@ func (bm *lockFreeManager) preparePackDataContent(ctx context.Context, pp *pendi packFileIndex := packIndexBuilder{} haveContent := false + var encryptedTmp []byte + for contentID, info := range pp.currentPackItems { if info.Payload == nil { // no payload, it's a deletion of a previously-committed content. @@ -272,9 +274,7 @@ func (bm *lockFreeManager) preparePackDataContent(ctx context.Context, pp *pendi haveContent = true - var encrypted []byte - - encrypted, err = bm.maybeEncryptContentDataForPacking(info.Payload, info.ID) + encryptedTmp, err = bm.maybeEncryptContentDataForPacking(encryptedTmp[:0], info.Payload, info.ID) if err != nil { return nil, nil, errors.Wrapf(err, "unable to encrypt %q", info.ID) } @@ -287,15 +287,15 @@ func (bm *lockFreeManager) preparePackDataContent(ctx context.Context, pp *pendi FormatVersion: byte(bm.writeFormatVersion), PackBlobID: packFile, PackOffset: uint32(len(contentData)), - Length: uint32(len(encrypted)), + Length: uint32(len(encryptedTmp)), TimestampSeconds: info.TimestampSeconds, }) if contentID.HasPrefix() { - bm.metadataCache.put(ctx, cacheKey(contentID), cloneBytes(encrypted)) + bm.metadataCache.put(ctx, cacheKey(contentID), cloneBytes(encryptedTmp)) } - contentData = append(contentData, encrypted...) + contentData = append(contentData, encryptedTmp...) } if len(packFileIndex) == 0 { @@ -341,7 +341,7 @@ func (bm *lockFreeManager) getIndexBlobInternal(ctx context.Context, blobID blob bm.Stats.readContent(len(payload)) - payload, err = bm.encryptor.Decrypt(payload, iv) + payload, err = bm.encryptor.Decrypt(nil, payload, iv) bm.Stats.decrypted(len(payload)) if err != nil { @@ -376,14 +376,15 @@ func (bm *lockFreeManager) writePackFileNotLocked(ctx context.Context, packFile return bm.st.PutBlob(ctx, packFile, data) } -func (bm *lockFreeManager) encryptAndWriteContentNotLocked(ctx context.Context, data []byte, prefix blob.ID) (blob.ID, error) { - hash := bm.hashData(data) +func (bm *lockFreeManager) encryptAndWriteBlobNotLocked(ctx context.Context, data []byte, prefix blob.ID) (blob.ID, error) { + var hashOutput [maxHashSize]byte + + hash := bm.hashData(hashOutput[:0], data) blobID := prefix + blob.ID(hex.EncodeToString(hash)) - // Encrypt the content in-place. bm.Stats.encrypted(len(data)) - data2, err := bm.encryptor.Encrypt(data, hash) + data2, err := bm.encryptor.Encrypt(nil, data, hash) if err != nil { return "", err } @@ -398,20 +399,22 @@ func (bm *lockFreeManager) encryptAndWriteContentNotLocked(ctx context.Context, return blobID, nil } -func (bm *lockFreeManager) hashData(data []byte) []byte { +func (bm *lockFreeManager) hashData(output, data []byte) []byte { // Hash the content and compute encryption key. - contentID := bm.hasher(data) + contentID := bm.hasher(output, data) bm.Stats.hashedContent(len(data)) return contentID } func (bm *lockFreeManager) writePackIndexesNew(ctx context.Context, data []byte) (blob.ID, error) { - return bm.encryptAndWriteContentNotLocked(ctx, data, newIndexBlobPrefix) + return bm.encryptAndWriteBlobNotLocked(ctx, data, newIndexBlobPrefix) } func (bm *lockFreeManager) verifyChecksum(data, contentID []byte) error { - expected := bm.hasher(data) + var hashOutput [maxHashSize]byte + + expected := bm.hasher(hashOutput[:0], data) expected = expected[len(expected)-aes.BlockSize:] if !bytes.HasSuffix(contentID, expected) { @@ -437,9 +440,9 @@ func CreateHashAndEncryptor(f *FormattingOptions) (hashing.HashFunc, encryption. return nil, nil, errors.Wrap(err, "unable to create encryptor") } - contentID := h(nil) + contentID := h(nil, nil) - _, err = e.Encrypt(nil, contentID) + _, err = e.Encrypt(nil, nil, contentID) if err != nil { return nil, nil, errors.Wrap(err, "invalid encryptor") } diff --git a/repo/content/content_manager_test.go b/repo/content/content_manager_test.go index c4e5317a6..3f1bcde57 100644 --- a/repo/content/content_manager_test.go +++ b/repo/content/content_manager_test.go @@ -1245,7 +1245,7 @@ func verifyContentManagerDataSet(ctx context.Context, t *testing.T, mgr *Manager continue } - if !reflect.DeepEqual(v, originalPayload) { + if !bytes.Equal(v, originalPayload) { t.Errorf("payload for %q does not match original: %v", v, originalPayload) } } @@ -1307,7 +1307,7 @@ func verifyContent(ctx context.Context, t *testing.T, bm *Manager, contentID ID, return } - if got, want := b2, b; !reflect.DeepEqual(got, want) { + if got, want := b2, b; !bytes.Equal(got, want) { t.Errorf("content %q data mismatch: got %x (nil:%v), wanted %x (nil:%v)", contentID, got, got == nil, want, want == nil) } diff --git a/repo/content/packindex_test.go b/repo/content/packindex_test.go index 9a1ed848c..9111cbc69 100644 --- a/repo/content/packindex_test.go +++ b/repo/content/packindex_test.go @@ -117,11 +117,11 @@ func TestPackIndex(t *testing.T) { data2 := buf2.Bytes() data3 := buf3.Bytes() - if !reflect.DeepEqual(data1, data2) { + if !bytes.Equal(data1, data2) { t.Errorf("builder output not stable: %x vs %x", hex.Dump(data1), hex.Dump(data2)) } - if !reflect.DeepEqual(data2, data3) { + if !bytes.Equal(data2, data3) { t.Errorf("builder output not stable: %x vs %x", hex.Dump(data2), hex.Dump(data3)) } diff --git a/repo/encryption/aead_helpers.go b/repo/encryption/aead_helpers.go index 478376495..d9355f9c8 100644 --- a/repo/encryption/aead_helpers.go +++ b/repo/encryption/aead_helpers.go @@ -8,9 +8,15 @@ ) // aeadSealWithRandomNonce returns AEAD-sealed content prepended with random nonce. -func aeadSealWithRandomNonce(a cipher.AEAD, plaintext, contentID []byte) ([]byte, error) { - // pre-allocate a slice with len()=size of a nonce, and cap() for the entire ciphertext - result := make([]byte, a.NonceSize(), len(plaintext)+a.NonceSize()+a.Overhead()) +func aeadSealWithRandomNonce(result []byte, a cipher.AEAD, plaintext, contentID []byte) ([]byte, error) { + resultLen := len(plaintext) + a.NonceSize() + a.Overhead() + + if cap(result) < resultLen { + // result slice too small, make a new one + result = make([]byte, 0, resultLen) + } + + result = result[0:a.NonceSize()] n, err := rand.Read(result) if err != nil { @@ -25,10 +31,10 @@ func aeadSealWithRandomNonce(a cipher.AEAD, plaintext, contentID []byte) ([]byte } // aeadOpenPrefixedWithNonce opens AEAD-protected content, assuming first bytes are the nonce. -func aeadOpenPrefixedWithNonce(a cipher.AEAD, ciphertext, contentID []byte) ([]byte, error) { +func aeadOpenPrefixedWithNonce(output []byte, a cipher.AEAD, ciphertext, contentID []byte) ([]byte, error) { if len(ciphertext) < a.NonceSize() { return nil, errors.Errorf("ciphertext too short") } - return a.Open(nil, ciphertext[0:a.NonceSize()], ciphertext[a.NonceSize():], contentID) + return a.Open(output[:0], ciphertext[0:a.NonceSize()], ciphertext[a.NonceSize():], contentID) } diff --git a/repo/encryption/aes256_gcm_hmac_sha256_encryptor.go b/repo/encryption/aes256_gcm_hmac_sha256_encryptor.go index 3817e84b3..6abe81d6f 100644 --- a/repo/encryption/aes256_gcm_hmac_sha256_encryptor.go +++ b/repo/encryption/aes256_gcm_hmac_sha256_encryptor.go @@ -5,22 +5,28 @@ "crypto/cipher" "crypto/hmac" "crypto/sha256" + "hash" + "sync" "github.com/pkg/errors" ) type aes256GCMHmacSha256 struct { - keyDerivationSecret []byte + hmacPool *sync.Pool } // aeadForContent returns cipher.AEAD using key derived from a given contentID. func (e aes256GCMHmacSha256) aeadForContent(contentID []byte) (cipher.AEAD, error) { - h := hmac.New(sha256.New, e.keyDerivationSecret) + h := e.hmacPool.Get().(hash.Hash) + defer e.hmacPool.Put(h) + h.Reset() + if _, err := h.Write(contentID); err != nil { return nil, errors.Wrap(err, "unable to derive encryption key") } - key := h.Sum(nil) + var hashBuf [32]byte + key := h.Sum(hashBuf[:0]) c, err := aes.NewCipher(key) if err != nil { @@ -30,22 +36,22 @@ func (e aes256GCMHmacSha256) aeadForContent(contentID []byte) (cipher.AEAD, erro return cipher.NewGCM(c) } -func (e aes256GCMHmacSha256) Decrypt(input, contentID []byte) ([]byte, error) { +func (e aes256GCMHmacSha256) Decrypt(output, input, contentID []byte) ([]byte, error) { a, err := e.aeadForContent(contentID) if err != nil { return nil, err } - return aeadOpenPrefixedWithNonce(a, input, contentID) + return aeadOpenPrefixedWithNonce(output, a, input, contentID) } -func (e aes256GCMHmacSha256) Encrypt(input, contentID []byte) ([]byte, error) { +func (e aes256GCMHmacSha256) Encrypt(output, input, contentID []byte) ([]byte, error) { a, err := e.aeadForContent(contentID) if err != nil { return nil, err } - return aeadSealWithRandomNonce(a, input, contentID) + return aeadSealWithRandomNonce(output, a, input, contentID) } func (e aes256GCMHmacSha256) IsAuthenticated() bool { @@ -63,6 +69,12 @@ func init() { return nil, err } - return aes256GCMHmacSha256{keyDerivationSecret}, nil + hmacPool := &sync.Pool{ + New: func() interface{} { + return hmac.New(sha256.New, keyDerivationSecret) + }, + } + + return aes256GCMHmacSha256{hmacPool}, nil }) } diff --git a/repo/encryption/chacha20_poly1305_hmac_sha256_encryptor.go b/repo/encryption/chacha20_poly1305_hmac_sha256_encryptor.go index 584212877..a05795a00 100644 --- a/repo/encryption/chacha20_poly1305_hmac_sha256_encryptor.go +++ b/repo/encryption/chacha20_poly1305_hmac_sha256_encryptor.go @@ -4,43 +4,50 @@ "crypto/cipher" "crypto/hmac" "crypto/sha256" + "hash" + "sync" "github.com/pkg/errors" "golang.org/x/crypto/chacha20poly1305" ) type chacha20poly1305hmacSha256Encryptor struct { - keyDerivationSecret []byte + hmacPool *sync.Pool } // aeadForContent returns cipher.AEAD using key derived from a given contentID. func (e chacha20poly1305hmacSha256Encryptor) aeadForContent(contentID []byte) (cipher.AEAD, error) { - h := hmac.New(sha256.New, e.keyDerivationSecret) + h := e.hmacPool.Get().(hash.Hash) + defer e.hmacPool.Put(h) + + h.Reset() + if _, err := h.Write(contentID); err != nil { return nil, errors.Wrap(err, "unable to derive encryption key") } - key := h.Sum(nil) + var hashBuf [32]byte + key := h.Sum(hashBuf[:0]) return chacha20poly1305.New(key) } -func (e chacha20poly1305hmacSha256Encryptor) Decrypt(input, contentID []byte) ([]byte, error) { +func (e chacha20poly1305hmacSha256Encryptor) Decrypt(output, input, contentID []byte) ([]byte, error) { a, err := e.aeadForContent(contentID) if err != nil { return nil, err } - return aeadOpenPrefixedWithNonce(a, input, contentID) + return aeadOpenPrefixedWithNonce(output, a, input, contentID) } -func (e chacha20poly1305hmacSha256Encryptor) Encrypt(input, contentID []byte) ([]byte, error) { +func (e chacha20poly1305hmacSha256Encryptor) Encrypt(output, input, contentID []byte) ([]byte, error) { a, err := e.aeadForContent(contentID) if err != nil { return nil, err } - return aeadSealWithRandomNonce(a, input, contentID) + return aeadSealWithRandomNonce(output, a, input, contentID) } func (e chacha20poly1305hmacSha256Encryptor) IsAuthenticated() bool { @@ -58,6 +65,12 @@ func init() { return nil, err } - return chacha20poly1305hmacSha256Encryptor{keyDerivationSecret}, nil + hmacPool := &sync.Pool{ + New: func() interface{} { + return hmac.New(sha256.New, keyDerivationSecret) + }, + } + + return chacha20poly1305hmacSha256Encryptor{hmacPool}, nil }) } diff --git a/repo/encryption/deprecated_ctr_encryptor.go b/repo/encryption/deprecated_ctr_encryptor.go index 399e47cb2..3acb4cc95 100644 --- a/repo/encryption/deprecated_ctr_encryptor.go +++ b/repo/encryption/deprecated_ctr_encryptor.go @@ -12,12 +12,12 @@ type ctrEncryptor struct { createCipher func() (cipher.Block, error) } -func (fi ctrEncryptor) Encrypt(plainText, contentID []byte) ([]byte, error) { - return symmetricEncrypt(fi.createCipher, contentID, plainText) +func (fi ctrEncryptor) Encrypt(output, plainText, contentID []byte) ([]byte, error) { + return symmetricEncrypt(output, fi.createCipher, contentID, plainText) } -func (fi ctrEncryptor) Decrypt(cipherText, contentID []byte) ([]byte, error) { - return symmetricEncrypt(fi.createCipher, contentID, cipherText) +func (fi ctrEncryptor) Decrypt(output, cipherText, contentID []byte) ([]byte, error) { + return symmetricEncrypt(output, fi.createCipher, contentID, cipherText) } func (fi ctrEncryptor) IsAuthenticated() bool { @@ -28,7 +28,7 @@ func (fi ctrEncryptor) IsDeprecated() bool { return true } -func symmetricEncrypt(createCipher func() (cipher.Block, error), iv, b []byte) ([]byte, error) { +func symmetricEncrypt(output []byte, createCipher func() (cipher.Block, error), iv, b []byte) ([]byte, error) { blockCipher, err := createCipher() if err != nil { return nil, err @@ -39,8 +39,9 @@ func symmetricEncrypt(createCipher func() (cipher.Block, error), iv, b []byte) ( } ctr := cipher.NewCTR(blockCipher, iv[0:blockCipher.BlockSize()]) - result := make([]byte, len(b)) - ctr.XORKeyStream(result, b) + + result, out := sliceForAppend(output, len(b)) + ctr.XORKeyStream(out, b) return result, nil } diff --git a/repo/encryption/deprecated_salsa_encryptor.go b/repo/encryption/deprecated_salsa_encryptor.go index efc814f84..0b2c8eec2 100644 --- a/repo/encryption/deprecated_salsa_encryptor.go +++ b/repo/encryption/deprecated_salsa_encryptor.go @@ -20,7 +20,7 @@ type salsaEncryptor struct { hmacSecret []byte } -func (s salsaEncryptor) Decrypt(input, contentID []byte) ([]byte, error) { +func (s salsaEncryptor) Decrypt(output, input, contentID []byte) ([]byte, error) { if s.hmacSecret != nil { var err error @@ -30,11 +30,11 @@ func (s salsaEncryptor) Decrypt(input, contentID []byte) ([]byte, error) { } } - return s.encryptDecrypt(input, contentID) + return s.encryptDecrypt(output, input, contentID) } -func (s salsaEncryptor) Encrypt(input, contentID []byte) ([]byte, error) { - v, err := s.encryptDecrypt(input, contentID) +func (s salsaEncryptor) Encrypt(output, input, contentID []byte) ([]byte, error) { + v, err := s.encryptDecrypt(output, input, contentID) if err != nil { return nil, errors.Wrap(err, "decrypt") } @@ -50,14 +50,14 @@ func (s salsaEncryptor) IsAuthenticated() bool { return s.hmacSecret != nil } -func (s salsaEncryptor) encryptDecrypt(input, contentID []byte) ([]byte, error) { +func (s salsaEncryptor) encryptDecrypt(output, input, contentID []byte) ([]byte, error) { if len(contentID) < s.nonceSize { return nil, errors.Errorf("hash too short, expected >=%v bytes, got %v", s.nonceSize, len(contentID)) } - result := make([]byte, len(input)) + result, out := sliceForAppend(output, len(input)) nonce := contentID[0:s.nonceSize] - salsa20.XORKeyStream(result, input, nonce, s.key) + salsa20.XORKeyStream(out, input, nonce, s.key) return result, nil } diff --git a/repo/encryption/encryption.go b/repo/encryption/encryption.go index 4375c8568..076a7a3d4 100644 --- a/repo/encryption/encryption.go +++ b/repo/encryption/encryption.go @@ -14,14 +14,14 @@ // Encryptor performs encryption and decryption of contents of data. type Encryptor interface { - // Encrypt returns encrypted bytes corresponding to the given plaintext. + // Encrypt appends the encrypted bytes corresponding to the given plaintext to a given slice. // Must not clobber the input slice and return ciphertext with additional padding and checksum. - Encrypt(plainText, contentID []byte) ([]byte, error) + Encrypt(output, plainText, contentID []byte) ([]byte, error) - // Decrypt returns unencrypted bytes corresponding to the given ciphertext. + // Decrypt appends the unencrypted bytes corresponding to the given ciphertext to a given slice. // Must not clobber the input slice. If IsAuthenticated() == true, Decrypt will perform // authenticity check before decrypting. - Decrypt(cipherText, contentID []byte) ([]byte, error) + Decrypt(output, cipherText, contentID []byte) ([]byte, error) // IsAuthenticated returns true if encryption is authenticated. // In this case Decrypt() is expected to perform authenticity check. @@ -91,10 +91,6 @@ type encryptorInfo struct { var encryptors = map[string]*encryptorInfo{} -func cloneBytes(b []byte) []byte { - return append([]byte{}, b...) -} - // deriveKey uses HKDF to derive a key of a given length and a given purpose from parameters. // nolint:unparam func deriveKey(p Parameters, purpose []byte, length int) ([]byte, error) { @@ -108,3 +104,23 @@ func deriveKey(p Parameters, purpose []byte, length int) ([]byte, error) { return key, nil } + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +// +// From: https://golang.org/src/crypto/cipher/gcm.go +// Copyright 2013 The Go Authors. All rights reserved. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + + tail = head[len(in):] + + return +} diff --git a/repo/encryption/encryption_test.go b/repo/encryption/encryption_test.go index aa98439d1..23806ce53 100644 --- a/repo/encryption/encryption_test.go +++ b/repo/encryption/encryption_test.go @@ -40,13 +40,13 @@ func TestRoundTrip(t *testing.T) { t.Fatal(err) } - cipherText1, err := e.Encrypt(data, contentID1) + cipherText1, err := e.Encrypt(nil, data, contentID1) if err != nil || cipherText1 == nil { t.Errorf("invalid response from Encrypt: %v %v", cipherText1, err) } if !e.IsDeprecated() && encryptionAlgo != encryption.NoneAlgorithm { - cipherText1b, err2 := e.Encrypt(data, contentID1) + cipherText1b, err2 := e.Encrypt(nil, data, contentID1) if err2 != nil || cipherText1b == nil { t.Errorf("invalid response from Encrypt: %v %v", cipherText1, err2) } @@ -56,7 +56,7 @@ func TestRoundTrip(t *testing.T) { } } - plainText1, err := e.Decrypt(cipherText1, contentID1) + plainText1, err := e.Decrypt(nil, cipherText1, contentID1) if err != nil || plainText1 == nil { t.Errorf("invalid response from Decrypt: %v %v", plainText1, err) } @@ -65,12 +65,23 @@ func TestRoundTrip(t *testing.T) { t.Errorf("Encrypt()/Decrypt() does not round-trip: %x %x", plainText1, data) } - cipherText2, err := e.Encrypt(data, contentID2) + plaintextOutput := make([]byte, 0, 256) + + plainText1a, err := e.Decrypt(plaintextOutput, cipherText1, contentID1) + if err != nil || plainText1 == nil { + t.Errorf("invalid response from Decrypt: %v %v", plainText1, err) + } + + if !bytes.Equal(plainText1a, plaintextOutput[0:len(plainText1a)]) { + t.Errorf("Decrypt() does not use output buffer") + } + + cipherText2, err := e.Encrypt(nil, data, contentID2) if err != nil || cipherText2 == nil { t.Errorf("invalid response from Encrypt: %v %v", cipherText2, err) } - plainText2, err := e.Decrypt(cipherText2, contentID2) + plainText2, err := e.Decrypt(nil, cipherText2, contentID2) if err != nil || plainText2 == nil { t.Errorf("invalid response from Decrypt: %v %v", plainText2, err) } @@ -85,7 +96,7 @@ func TestRoundTrip(t *testing.T) { } // decrypt using wrong content ID - badPlainText2, err := e.Decrypt(cipherText2, contentID1) + badPlainText2, err := e.Decrypt(nil, cipherText2, contentID1) if e.IsAuthenticated() { if err == nil && encryptionAlgo != "SALSA20-HMAC" { // "SALSA20-HMAC" is deprecated & wrong, and only validates that checksum is @@ -102,7 +113,7 @@ func TestRoundTrip(t *testing.T) { // flip some bits in the cipherText if e.IsAuthenticated() { cipherText2[mathrand.Intn(len(cipherText2))] ^= byte(1 + mathrand.Intn(254)) - if _, err := e.Decrypt(cipherText2, contentID1); err == nil { + if _, err := e.Decrypt(nil, cipherText2, contentID1); err == nil { t.Errorf("expected decrypt failure on invalid ciphertext, got success") } } @@ -173,7 +184,7 @@ func verifyCiphertextSamples(t *testing.T, masterKey, contentID, payload []byte, ct := samples[encryptionAlgo] if ct == "" { - v, err := enc.Encrypt(payload, contentID) + v, err := enc.Encrypt(nil, payload, contentID) if err != nil { t.Fatal(err) } @@ -186,7 +197,7 @@ func verifyCiphertextSamples(t *testing.T, masterKey, contentID, payload []byte, continue } - plainText, err := enc.Decrypt(b, contentID) + plainText, err := enc.Decrypt(nil, b, contentID) if err != nil { t.Errorf("unable to decrypt %v: %v", encryptionAlgo, err) continue diff --git a/repo/encryption/null_encryptor.go b/repo/encryption/null_encryptor.go index ef91c4bfa..8982c14d0 100644 --- a/repo/encryption/null_encryptor.go +++ b/repo/encryption/null_encryptor.go @@ -4,12 +4,12 @@ type nullEncryptor struct { } -func (fi nullEncryptor) Encrypt(plainText, contentID []byte) ([]byte, error) { - return cloneBytes(plainText), nil +func (fi nullEncryptor) Encrypt(output, plainText, contentID []byte) ([]byte, error) { + return append(output, plainText...), nil } -func (fi nullEncryptor) Decrypt(cipherText, contentID []byte) ([]byte, error) { - return cloneBytes(cipherText), nil +func (fi nullEncryptor) Decrypt(output, cipherText, contentID []byte) ([]byte, error) { + return append(output, cipherText...), nil } func (fi nullEncryptor) IsAuthenticated() bool { diff --git a/repo/hashing/hashing.go b/repo/hashing/hashing.go index f4a0894ce..66ced5b00 100644 --- a/repo/hashing/hashing.go +++ b/repo/hashing/hashing.go @@ -5,6 +5,7 @@ "crypto/hmac" "hash" "sort" + "sync" "github.com/pkg/errors" ) @@ -16,7 +17,7 @@ type Parameters interface { } // HashFunc computes hash of content of data using a cryptographic hash function, possibly with HMAC and/or truncation. -type HashFunc func(data []byte) []byte +type HashFunc func(output, data []byte) []byte // HashFuncFactory returns a hash function for given formatting options. type HashFuncFactory func(p Parameters) (HashFunc, error) @@ -47,11 +48,20 @@ func SupportedAlgorithms() []string { // and truncates results to the given size. func truncatedHMACHashFuncFactory(hf func() hash.Hash, truncate int) HashFuncFactory { return func(p Parameters) (HashFunc, error) { - return func(b []byte) []byte { - h := hmac.New(hf, p.GetHMACSecret()) + pool := sync.Pool{ + New: func() interface{} { + return hmac.New(hf, p.GetHMACSecret()) + }, + } + + return func(output, b []byte) []byte { + h := pool.Get().(hash.Hash) + defer pool.Put(h) + + h.Reset() h.Write(b) // nolint:errcheck - return h.Sum(nil)[0:truncate] + return h.Sum(output)[0:truncate] }, nil } } @@ -60,15 +70,26 @@ func truncatedHMACHashFuncFactory(hf func() hash.Hash, truncate int) HashFuncFac // and truncates results to the given size. func truncatedKeyedHashFuncFactory(hf func(key []byte) (hash.Hash, error), truncate int) HashFuncFactory { return func(p Parameters) (HashFunc, error) { - if _, err := hf(p.GetHMACSecret()); err != nil { + secret := p.GetHMACSecret() + if _, err := hf(secret); err != nil { return nil, err } - return func(b []byte) []byte { - h, _ := hf(p.GetHMACSecret()) + pool := sync.Pool{ + New: func() interface{} { + h, _ := hf(secret) + return h + }, + } + + return func(output, b []byte) []byte { + h := pool.Get().(hash.Hash) + defer pool.Put(h) + + h.Reset() h.Write(b) // nolint:errcheck - return h.Sum(nil)[0:truncate] + return h.Sum(output)[0:truncate] }, nil } } diff --git a/repo/hashing/hashing_test.go b/repo/hashing/hashing_test.go index d563c1743..cd93bec32 100644 --- a/repo/hashing/hashing_test.go +++ b/repo/hashing/hashing_test.go @@ -36,14 +36,19 @@ func TestRoundTrip(t *testing.T) { t.Fatal(err) } - hash1a := f(data1) - hash1b := f(data1) - hash2 := f(data2) + outputBuffer := make([]byte, 0, 256) + hash1a := f(nil, data1) + hash1b := f(outputBuffer, data1) + hash2 := f(nil, data2) if !bytes.Equal(hash1a, hash1b) { t.Fatalf("hashing not stable: %x %x", hash1a, hash1b) } + if !bytes.Equal(hash1a, outputBuffer[0:len(hash1a)]) { + t.Fatalf("hash did not populate output buffer") + } + if bytes.Equal(hash1a, hash2) { t.Fatalf("hashing should produce different results: %x", hash1a) } diff --git a/tests/stress_test/stress_test.go b/tests/stress_test/stress_test.go index c51784214..8cd924129 100644 --- a/tests/stress_test/stress_test.go +++ b/tests/stress_test/stress_test.go @@ -1,11 +1,11 @@ package stress_test import ( + "bytes" "context" "fmt" "math/rand" "os" - "reflect" "testing" "time" @@ -127,7 +127,7 @@ type writtenBlock struct { return } - if !reflect.DeepEqual(previous.data, d2) { + if !bytes.Equal(previous.data, d2) { t.Errorf("invalid previous data for %q %x %x", previous.contentID, d2, previous.data) return }