diff --git a/snapshot/policy/compression_policy.go b/snapshot/policy/compression_policy.go index 4b9e93747..4c7822720 100644 --- a/snapshot/policy/compression_policy.go +++ b/snapshot/policy/compression_policy.go @@ -30,7 +30,7 @@ type CompressionPolicyDefinition struct { } // CompressorForFile returns compression name to be used for compressing a given file according to policy, using attributes such as name or size. -func (p *CompressionPolicy) CompressorForFile(e fs.File) compression.Name { +func (p *CompressionPolicy) CompressorForFile(e fs.Entry) compression.Name { ext := filepath.Ext(e.Name()) size := e.Size() diff --git a/snapshot/snapshotfs/upload.go b/snapshot/snapshotfs/upload.go index d615eb003..f626a670a 100644 --- a/snapshot/snapshotfs/upload.go +++ b/snapshot/snapshotfs/upload.go @@ -336,7 +336,7 @@ func (u *Uploader) uploadSymlinkInternal(ctx context.Context, relativePath strin return de, nil } -func (u *Uploader) uploadStreamingFileInternal(ctx context.Context, relativePath string, f fs.StreamingFile) (dirEntry *snapshot.DirEntry, ret error) { +func (u *Uploader) uploadStreamingFileInternal(ctx context.Context, relativePath string, f fs.StreamingFile, pol *policy.Policy) (dirEntry *snapshot.DirEntry, ret error) { reader, err := f.GetReader(ctx) if err != nil { return nil, errors.Wrap(err, "unable to get streaming file reader") @@ -353,9 +353,12 @@ func (u *Uploader) uploadStreamingFileInternal(ctx context.Context, relativePath u.Progress.FinishedFile(relativePath, ret) }() + comp := pol.CompressionPolicy.CompressorForFile(f) writer := u.repo.NewObjectWriter(ctx, object.WriterOptions{ Description: "STREAMFILE:" + f.Name(), + Compressor: comp, }) + defer writer.Close() //nolint:errcheck written, err := u.copyWithProgress(writer, reader) @@ -919,7 +922,7 @@ func (u *Uploader) processSingle( case fs.StreamingFile: atomic.AddInt32(&u.stats.NonCachedFiles, 1) - de, err := u.uploadStreamingFileInternal(ctx, entryRelativePath, entry) + de, err := u.uploadStreamingFileInternal(ctx, entryRelativePath, entry, policyTree.Child(entry.Name()).EffectivePolicy()) return u.processEntryUploadResult(ctx, de, err, entryRelativePath, parentDirBuilder, policyTree.EffectivePolicy().ErrorHandlingPolicy.IgnoreFileErrors.OrDefault(false), diff --git a/snapshot/snapshotfs/upload_test.go b/snapshot/snapshotfs/upload_test.go index fbb026c82..01a262b4f 100644 --- a/snapshot/snapshotfs/upload_test.go +++ b/snapshot/snapshotfs/upload_test.go @@ -873,6 +873,42 @@ func TestUpload_VirtualDirectoryWithStreamingFile(t *testing.T) { } } +func TestUpload_VirtualDirectoryWithStreamingFile_WithCompression(t *testing.T) { + ctx := testlogging.Context(t) + th := newUploadTestHarness(ctx, t) + + defer th.cleanup() + + u := NewUploader(th.repo) + + pol := *policy.DefaultPolicy + pol.CompressionPolicy.CompressorName = "pgzip" + + policyTree := policy.BuildTree(nil, &pol) + + // Create a temporary file with test data. Want something compressible but + // small so we don't trigger dedupe. + content := []byte(strings.Repeat("a", 4096)) + r := io.NopCloser(bytes.NewReader(content)) + + staticRoot := virtualfs.NewStaticDirectory("rootdir", []fs.Entry{ + virtualfs.StreamingFileFromReader("stream-file", r), + }) + + man, err := u.Upload(ctx, staticRoot, policyTree, snapshot.SourceInfo{}) + require.NoError(t, err) + + assert.Equal(t, int32(0), atomic.LoadInt32(&man.Stats.CachedFiles), "cached file count") + assert.Equal(t, int32(1), atomic.LoadInt32(&man.Stats.NonCachedFiles), "non-cached file count") + assert.Equal(t, int32(1), atomic.LoadInt32(&man.Stats.TotalDirectoryCount), "directory count") + assert.Equal(t, int32(1), atomic.LoadInt32(&man.Stats.TotalFileCount), "total file count") + + // Write out pending data so the below size check compares properly. + require.NoError(t, th.repo.Flush(ctx), "flushing repo") + + assert.Less(t, testutil.MustGetTotalDirSize(t, th.repoDir), int64(14000)) +} + func TestUpload_VirtualDirectoryWithStreamingFileWithModTime(t *testing.T) { content := []byte("Streaming Temporary file content") mt := time.Date(2021, 1, 2, 3, 4, 5, 0, time.UTC)