diff --git a/cli/app.go b/cli/app.go index 18c710cf8..8ca4a1b8d 100644 --- a/cli/app.go +++ b/cli/app.go @@ -89,7 +89,7 @@ type appServices interface { stdout() io.Writer Stderr() io.Writer stdin() io.Reader - onCtrlC(callback func()) + onTerminate(callback func()) onRepositoryFatalError(callback func(err error)) enableTestOnlyFlags() bool EnvName(s string) string diff --git a/cli/command_mount.go b/cli/command_mount.go index d6e236623..bd1a82dc3 100644 --- a/cli/command_mount.go +++ b/cli/command_mount.go @@ -103,7 +103,7 @@ func (c *commandMount) run(ctx context.Context, rep repo.Repository) error { // Wait until ctrl-c pressed or until the directory is unmounted. ctrlCPressed := make(chan bool) - c.svc.onCtrlC(func() { + c.svc.onTerminate(func() { close(ctrlCPressed) }) diff --git a/cli/command_repository_upgrade.go b/cli/command_repository_upgrade.go index fe1a3830a..546903dc4 100644 --- a/cli/command_repository_upgrade.go +++ b/cli/command_repository_upgrade.go @@ -392,7 +392,7 @@ func (c *commandRepositoryUpgrade) sleepWithContext(ctx context.Context, dur tim stop := make(chan struct{}) - c.svc.onCtrlC(func() { close(stop) }) + c.svc.onTerminate(func() { close(stop) }) select { case <-ctx.Done(): diff --git a/cli/command_server_start.go b/cli/command_server_start.go index 10587831f..2e7f1db00 100644 --- a/cli/command_server_start.go +++ b/cli/command_server_start.go @@ -222,7 +222,7 @@ func (c *commandServerStart) run(ctx context.Context) error { return nil } - c.svc.onCtrlC(func() { + c.svc.onTerminate(func() { log(ctx).Infof("Shutting down...") if serr := httpServer.Shutdown(ctx); serr != nil { diff --git a/cli/command_snapshot_create.go b/cli/command_snapshot_create.go index 395a250a4..c998ed87c 100644 --- a/cli/command_snapshot_create.go +++ b/cli/command_snapshot_create.go @@ -233,7 +233,7 @@ func (c *commandSnapshotCreate) setupUploader(rep repo.RepositoryWriter) *snapsh u.CheckpointInterval = interval } - c.svc.onCtrlC(u.Cancel) + c.svc.onTerminate(u.Cancel) u.ForceHashPercentage = c.snapshotCreateForceHash u.ParallelUploads = c.snapshotCreateParallelUploads diff --git a/cli/command_snapshot_migrate.go b/cli/command_snapshot_migrate.go index bd4c9e4f3..724ff7b3b 100644 --- a/cli/command_snapshot_migrate.go +++ b/cli/command_snapshot_migrate.go @@ -68,7 +68,7 @@ func (c *commandSnapshotMigrate) run(ctx context.Context, destRepo repo.Reposito c.svc.getProgress().StartShared() - c.svc.onCtrlC(func() { + c.svc.onTerminate(func() { mu.Lock() defer mu.Unlock() diff --git a/cli/config.go b/cli/config.go index 1278ef2a4..3daac638a 100644 --- a/cli/config.go +++ b/cli/config.go @@ -8,6 +8,7 @@ "os/signal" "path/filepath" "runtime" + "syscall" "github.com/alecthomas/kingpin/v2" "github.com/pkg/errors" @@ -29,9 +30,9 @@ func (c *App) onRepositoryFatalError(f func(err error)) { c.onFatalErrorCallbacks = append(c.onFatalErrorCallbacks, f) } -func (c *App) onCtrlC(f func()) { +func (c *App) onTerminate(f func()) { s := make(chan os.Signal, 1) - signal.Notify(s, os.Interrupt) + signal.Notify(s, os.Interrupt, syscall.SIGTERM) go func() { // invoke the function when either real or simulated Ctrl-C signal is delivered diff --git a/cli/inproc.go b/cli/inproc.go index f625c47d9..204a4e022 100644 --- a/cli/inproc.go +++ b/cli/inproc.go @@ -3,6 +3,7 @@ import ( "context" "io" + "os" "github.com/alecthomas/kingpin/v2" @@ -12,7 +13,7 @@ // RunSubcommand executes the subcommand asynchronously in current process // with flags in an isolated CLI environment and returns standard output and standard error. -func (c *App) RunSubcommand(ctx context.Context, kpapp *kingpin.Application, stdin io.Reader, argsAndFlags []string) (stdout, stderr io.Reader, wait func() error, kill func()) { +func (c *App) RunSubcommand(ctx context.Context, kpapp *kingpin.Application, stdin io.Reader, argsAndFlags []string) (stdout, stderr io.Reader, wait func() error, interrupt func(os.Signal)) { stdoutReader, stdoutWriter := io.Pipe() stderrReader, stderrWriter := io.Pipe() @@ -59,7 +60,7 @@ func (c *App) RunSubcommand(ctx context.Context, kpapp *kingpin.Application, std return stdoutReader, stderrReader, func() error { return <-resultErr - }, func() { + }, func(_ os.Signal) { // deliver simulated Ctrl-C to the app. c.simulatedCtrlC <- true } diff --git a/cli/terminate_signal_test.go b/cli/terminate_signal_test.go new file mode 100644 index 000000000..4813a0ea2 --- /dev/null +++ b/cli/terminate_signal_test.go @@ -0,0 +1,30 @@ +package cli_test + +import ( + "strings" + "syscall" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/kopia/kopia/tests/testenv" +) + +// Waits until the server advertises its address on the line. +func serverStarted(line string) bool { + return !strings.HasPrefix(line, "SERVER ADDRESS: ") +} + +func TestTerminate(t *testing.T) { + env := testenv.NewCLITest(t, testenv.RepoFormatNotImportant, testenv.NewExeRunner(t)) + + env.RunAndExpectSuccess(t, "repo", "create", "filesystem", "--path", env.RepoDir) + + wait, interrupt := env.RunAndProcessStderrInt(t, serverStarted, "server", "start", + "--address=localhost:0", + "--insecure") + + interrupt(syscall.SIGTERM) + + require.NoError(t, wait()) +} diff --git a/tests/testenv/cli_exe_runner.go b/tests/testenv/cli_exe_runner.go index 91ec037e8..39ee62e20 100644 --- a/tests/testenv/cli_exe_runner.go +++ b/tests/testenv/cli_exe_runner.go @@ -20,7 +20,7 @@ type CLIExeRunner struct { } // Start implements CLIRunner. -func (e *CLIExeRunner) Start(t *testing.T, args []string, env map[string]string) (stdout, stderr io.Reader, wait func() error, kill func()) { +func (e *CLIExeRunner) Start(t *testing.T, args []string, env map[string]string) (stdout, stderr io.Reader, wait func() error, interrupt func(os.Signal)) { t.Helper() c := exec.Command(e.Exe, append([]string{ @@ -51,8 +51,14 @@ func (e *CLIExeRunner) Start(t *testing.T, args []string, env map[string]string) t.Fatalf("unable to start: %v", err) } - return stdoutPipe, stderrPipe, c.Wait, func() { - c.Process.Kill() + return stdoutPipe, stderrPipe, c.Wait, func(sig os.Signal) { + if sig == os.Kill { + c.Process.Kill() + + return + } + + c.Process.Signal(sig) } } diff --git a/tests/testenv/cli_inproc_runner.go b/tests/testenv/cli_inproc_runner.go index f5afae331..617a89146 100644 --- a/tests/testenv/cli_inproc_runner.go +++ b/tests/testenv/cli_inproc_runner.go @@ -27,7 +27,7 @@ type CLIInProcRunner struct { } // Start implements CLIRunner. -func (e *CLIInProcRunner) Start(t *testing.T, args []string, env map[string]string) (stdout, stderr io.Reader, wait func() error, kill func()) { +func (e *CLIInProcRunner) Start(t *testing.T, args []string, env map[string]string) (stdout, stderr io.Reader, wait func() error, interrupt func(os.Signal)) { t.Helper() ctx := testlogging.Context(t) diff --git a/tests/testenv/cli_test_env.go b/tests/testenv/cli_test_env.go index df001eeee..79849218b 100644 --- a/tests/testenv/cli_test_env.go +++ b/tests/testenv/cli_test_env.go @@ -30,7 +30,7 @@ // CLIRunner encapsulates running kopia subcommands for testing purposes. // It supports implementations that use subprocesses or in-process invocations. type CLIRunner interface { - Start(t *testing.T, args []string, env map[string]string) (stdout, stderr io.Reader, wait func() error, kill func()) + Start(t *testing.T, args []string, env map[string]string) (stdout, stderr io.Reader, wait func() error, interrupt func(os.Signal)) } // CLITest encapsulates state for a CLI-based test. @@ -165,7 +165,20 @@ func (e *CLITest) getLogOutputPrefix() (string, bool) { func (e *CLITest) RunAndProcessStderr(t *testing.T, callback func(line string) bool, args ...string) (wait func() error, kill func()) { t.Helper() - stdout, stderr, wait, kill := e.Runner.Start(t, e.cmdArgs(args), e.Environment) + wait, interrupt := e.RunAndProcessStderrInt(t, callback, args...) + kill = func() { + interrupt(os.Kill) + } + + return wait, kill +} + +// RunAndProcessStderrInt runs the given command, and streams its output +// line-by-line to outputCallback until it returns false. +func (e *CLITest) RunAndProcessStderrInt(t *testing.T, outputCallback func(line string) bool, args ...string) (wait func() error, interrupt func(os.Signal)) { + t.Helper() + + stdout, stderr, wait, interrupt := e.Runner.Start(t, e.cmdArgs(args), e.Environment) go func() { scanner := bufio.NewScanner(stdout) @@ -182,7 +195,7 @@ func (e *CLITest) RunAndProcessStderr(t *testing.T, callback func(line string) b scanner := bufio.NewScanner(stderr) for scanner.Scan() { - if !callback(scanner.Text()) { + if !outputCallback(scanner.Text()) { break } } @@ -200,7 +213,7 @@ func (e *CLITest) RunAndProcessStderr(t *testing.T, callback func(line string) b } }() - return wait, kill + return wait, interrupt } // RunAndExpectSuccessWithErrOut runs the given command, expects it to succeed and returns its stdout and stderr lines.