From 851fe6c1937ce30a4835db322d8b34e1946731b7 Mon Sep 17 00:00:00 2001 From: Maxim Schuwalow <16665913+mschuwalow@users.noreply.github.com> Date: Wed, 12 Jul 2023 02:34:30 +0200 Subject: [PATCH] feat(cli): add --override-source parameter to allow overriding source when snapshotting (#3041) --- cli/command_snapshot_create.go | 98 +++++++++++++------ tests/end_to_end_test/snapshot_create_test.go | 55 +++++++++++ 2 files changed, 122 insertions(+), 31 deletions(-) diff --git a/cli/command_snapshot_create.go b/cli/command_snapshot_create.go index e0e872123..3454336d5 100644 --- a/cli/command_snapshot_create.go +++ b/cli/command_snapshot_create.go @@ -39,6 +39,7 @@ type commandSnapshotCreate struct { snapshotCreateCheckpointUploadLimitMB int64 snapshotCreateTags []string flushPerSource bool + sourceOverride string pins []string @@ -67,8 +68,9 @@ func (c *commandSnapshotCreate) setup(svc appServices, parent commandParent) { cmd.Flag("force-disable-actions", "Disable snapshot actions even if globally enabled on this client").Hidden().BoolVar(&c.snapshotCreateForceDisableActions) cmd.Flag("stdin-file", "File path to be used for stdin data snapshot.").StringVar(&c.snapshotCreateStdinFileName) cmd.Flag("tags", "Tags applied on the snapshot. Must be provided in the : format.").StringsVar(&c.snapshotCreateTags) - cmd.Flag("pin", "Create a pinned snapshot that's will not expire automatically").StringsVar(&c.pins) + cmd.Flag("pin", "Create a pinned snapshot that will not expire automatically").StringsVar(&c.pins) cmd.Flag("flush-per-source", "Flush writes at the end of each source").Hidden().BoolVar(&c.flushPerSource) + cmd.Flag("override-source", "Override the source of the snapshot.").StringVar(&c.sourceOverride) c.logDirDetail = -1 c.logEntryDetail = -1 @@ -127,19 +129,12 @@ func (c *commandSnapshotCreate) run(ctx context.Context, rep repo.RepositoryWrit break } - dir, err := filepath.Abs(snapshotDir) + fsEntry, sourceInfo, setManual, err := c.getContentToSnapshot(ctx, snapshotDir, rep) if err != nil { - finalErrors = append(finalErrors, fmt.Sprintf("invalid source: '%s': %s", snapshotDir, err)) - continue + finalErrors = append(finalErrors, fmt.Sprintf("failed to prepare source: %s", err)) } - sourceInfo := snapshot.SourceInfo{ - Path: filepath.Clean(dir), - Host: rep.ClientOptions().Hostname, - UserName: rep.ClientOptions().Username, - } - - if err := c.snapshotSingleSource(ctx, rep, u, sourceInfo, tags); err != nil { + if err := c.snapshotSingleSource(ctx, fsEntry, setManual, rep, u, sourceInfo, tags); err != nil { finalErrors = append(finalErrors, err.Error()) } } @@ -261,28 +256,10 @@ func startTimeAfterEndTime(startTime, endTime time.Time) bool { } //nolint:gocyclo -func (c *commandSnapshotCreate) snapshotSingleSource(ctx context.Context, rep repo.RepositoryWriter, u *snapshotfs.Uploader, sourceInfo snapshot.SourceInfo, tags map[string]string) error { +func (c *commandSnapshotCreate) snapshotSingleSource(ctx context.Context, fsEntry fs.Entry, setManual bool, rep repo.RepositoryWriter, u *snapshotfs.Uploader, sourceInfo snapshot.SourceInfo, tags map[string]string) error { log(ctx).Infof("Snapshotting %v ...", sourceInfo) - var ( - err error - fsEntry fs.Entry - setManual bool - ) - - if c.snapshotCreateStdinFileName != "" { - // stdin source will be snapshotted using a virtual static root directory with a single streaming file entry - // Create a new static directory with the given name and add a streaming file entry with os.Stdin reader - fsEntry = virtualfs.NewStaticDirectory(sourceInfo.Path, []fs.Entry{ - virtualfs.StreamingFileFromReader(c.snapshotCreateStdinFileName, io.NopCloser(c.svc.stdin())), - }) - setManual = true - } else { - fsEntry, err = getLocalFSEntry(ctx, sourceInfo.Path) - if err != nil { - return errors.Wrap(err, "unable to get local filesystem entry") - } - } + var err error previous, err := findPreviousSnapshotManifest(ctx, rep, sourceInfo, nil) if err != nil { @@ -470,3 +447,62 @@ func shouldSnapshotSource(ctx context.Context, src snapshot.SourceInfo, rep repo src.UserName == rep.ClientOptions().Username && !policy.IsManualSnapshot(policyTree), nil } + +func (c *commandSnapshotCreate) getContentToSnapshot(ctx context.Context, dir string, rep repo.RepositoryWriter) (fs.Entry, snapshot.SourceInfo, bool, error) { + var ( + absDir string + sourceInfo snapshot.SourceInfo + fsEntry fs.Entry + setManual bool + err error + ) + + absDir, err = filepath.Abs(dir) + if err != nil { + return nil, sourceInfo, false, errors.Wrapf(err, "invalid source %v", dir) + } + + if c.sourceOverride != "" { + sourceInfo, err = parseFullSource(c.sourceOverride, rep) + + if err != nil { + return nil, sourceInfo, false, errors.Wrapf(err, "invalid source override %v", c.sourceOverride) + } + + setManual = true + } else { + sourceInfo = snapshot.SourceInfo{ + Path: filepath.Clean(absDir), + Host: rep.ClientOptions().Hostname, + UserName: rep.ClientOptions().Username, + } + } + + if c.snapshotCreateStdinFileName != "" { + // stdin source will be snapshotted using a virtual static root directory with a single streaming file entry + // Create a new static directory with the given name and add a streaming file entry with os.Stdin reader + fsEntry = virtualfs.NewStaticDirectory(absDir, []fs.Entry{ + virtualfs.StreamingFileFromReader(c.snapshotCreateStdinFileName, io.NopCloser(c.svc.stdin())), + }) + setManual = true + } else { + fsEntry, err = getLocalFSEntry(ctx, absDir) + if err != nil { + return nil, sourceInfo, false, errors.Wrap(err, "unable to get local filesystem entry") + } + } + + return fsEntry, sourceInfo, setManual, nil +} + +func parseFullSource(str string, rep repo.RepositoryWriter) (snapshot.SourceInfo, error) { + sourceInfo, err := snapshot.ParseSourceInfo(str, rep.ClientOptions().Hostname, rep.ClientOptions().Username) + + if err != nil { + return snapshot.SourceInfo{}, errors.Wrapf(err, "not a valid source %v", str) + } else if sourceInfo.Host == "" || sourceInfo.UserName == "" || sourceInfo.Path == "" { + return snapshot.SourceInfo{}, errors.Errorf("source does not resolve into host, user and path: '%s'", str) + } + + return sourceInfo, nil +} diff --git a/tests/end_to_end_test/snapshot_create_test.go b/tests/end_to_end_test/snapshot_create_test.go index 3dbb9042c..a8c4539dd 100644 --- a/tests/end_to_end_test/snapshot_create_test.go +++ b/tests/end_to_end_test/snapshot_create_test.go @@ -5,6 +5,8 @@ "path" "path/filepath" "reflect" + "regexp" + "runtime" "sort" "strings" "testing" @@ -18,6 +20,7 @@ "github.com/kopia/kopia/internal/cachedir" "github.com/kopia/kopia/internal/testutil" "github.com/kopia/kopia/snapshot" + "github.com/kopia/kopia/snapshot/policy" "github.com/kopia/kopia/tests/clitestutil" "github.com/kopia/kopia/tests/testenv" ) @@ -728,3 +731,55 @@ func TestSnapshotCreateAllFlushPerSource(t *testing.T) { require.Len(t, indexList3, len(indexList2)+3) require.Len(t, metadataBlobList3, len(metadataBlobList2)+3) } + +func TestSnapshotCreateAllSnapshotPath(t *testing.T) { + t.Parallel() + + runner := testenv.NewInProcRunner(t) + e := testenv.NewCLITest(t, testenv.RepoFormatNotImportant, runner) + + defer e.RunAndExpectSuccess(t, "repo", "disconnect") + + e.RunAndExpectSuccess(t, "repo", "create", "filesystem", "--path", e.RepoDir, "--override-hostname=foo", "--override-username=foo") + e.RunAndExpectSuccess(t, "snapshot", "create", "--override-source", "bar@bar:/foo/bar", sharedTestDataDir1) + e.RunAndExpectSuccess(t, "snapshot", "create", "--override-source", "bar@bar:C:\\foo\\baz", sharedTestDataDir2) + e.RunAndExpectSuccess(t, "snapshot", "create", "--override-source", "/foo/bar", sharedTestDataDir3) + + // Make sure the scheduling policy with manual field is set and visible in the policy list, includes global policy + var plist []policy.TargetWithPolicy + + testutil.MustParseJSONLines(t, e.RunAndExpectSuccess(t, "policy", "list", "--json"), &plist) + + if got, want := len(plist), 4; got != want { + t.Fatalf("got %v policies, wanted %v", got, want) + } + + // all non-global policies should be manual + for _, p := range plist { + if (p.Target != snapshot.SourceInfo{}) { + require.True(t, p.Policy.SchedulingPolicy.Manual) + } + } + + si := clitestutil.ListSnapshotsAndExpectSuccess(t, e, "--all") + if got, want := len(si), 3; got != want { + t.Fatalf("got %v sources, wanted %v", got, want) + } + + require.Equal(t, "bar", si[0].User) + require.Equal(t, "bar", si[0].Host) + require.Equal(t, "/foo/bar", si[0].Path) + + require.Equal(t, "bar", si[1].User) + require.Equal(t, "bar", si[1].Host) + require.Equal(t, "C:\\foo\\baz", si[1].Path) + + require.Equal(t, "foo", si[2].User) + require.Equal(t, "foo", si[2].Host) + + if runtime.GOOS == "windows" { + require.Regexp(t, regexp.MustCompile(`[A-Z]:\\foo\\bar`), si[2].Path) + } else { + require.Equal(t, "/foo/bar", si[2].Path) + } +}