diff --git a/cli/app.go b/cli/app.go index 72e29d5fb..b8a1b411b 100644 --- a/cli/app.go +++ b/cli/app.go @@ -12,6 +12,7 @@ "github.com/alecthomas/kingpin" "github.com/fatih/color" + "github.com/gorilla/mux" "github.com/mattn/go-colorable" "github.com/pkg/errors" @@ -425,21 +426,21 @@ func (c *App) maybeRepositoryAction(act func(ctx context.Context, rep repo.Repos defer gather.DumpStats(ctx) if c.metricsListenAddr != "" { - mux := http.NewServeMux() - if err := initPrometheus(mux); err != nil { + m := mux.NewRouter() + if err := initPrometheus(m); err != nil { return errors.Wrap(err, "unable to initialize prometheus.") } if c.enablePProf { - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + m.HandleFunc("/debug/pprof/", pprof.Index) + m.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + m.HandleFunc("/debug/pprof/profile", pprof.Profile) + m.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + m.HandleFunc("/debug/pprof/trace", pprof.Trace) } log(ctx).Infof("starting prometheus metrics on %v", c.metricsListenAddr) - go http.ListenAndServe(c.metricsListenAddr, mux) // nolint:errcheck + go http.ListenAndServe(c.metricsListenAddr, m) // nolint:errcheck } memtrack.Dump(ctx, "before openRepository") diff --git a/cli/command_server_start.go b/cli/command_server_start.go index b71a2b79e..3670c00d6 100644 --- a/cli/command_server_start.go +++ b/cli/command_server_start.go @@ -14,6 +14,7 @@ "time" "contrib.go.opencensus.io/exporter/prometheus" + "github.com/gorilla/mux" "github.com/pkg/errors" prom "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" @@ -30,10 +31,11 @@ type commandServerStart struct { co connectOptions serverStartHTMLPath string - serverStartUI bool + serverStartUI bool serverStartLegacyRepositoryAPI bool serverStartGRPC bool + serverStartControlAPI bool serverStartRefreshInterval time.Duration serverStartInsecure bool @@ -63,6 +65,8 @@ type commandServerStart struct { logServerRequests bool + disableCSRFTokenChecks bool // disable CSRF token checks - used for development/debugging only + sf serverFlags svc advancedAppServices out textOutput @@ -75,6 +79,7 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent cmd.Flag("legacy-api", "Start the legacy server API").Default("true").BoolVar(&c.serverStartLegacyRepositoryAPI) cmd.Flag("grpc", "Start the GRPC server").Default("true").BoolVar(&c.serverStartGRPC) + cmd.Flag("control-api", "Start the control API").Default("true").BoolVar(&c.serverStartControlAPI) cmd.Flag("refresh-interval", "Frequency for refreshing repository status").Default("300s").DurationVar(&c.serverStartRefreshInterval) cmd.Flag("insecure", "Allow insecure configurations (do not use in production)").Hidden().BoolVar(&c.serverStartInsecure) @@ -104,6 +109,7 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent cmd.Flag("ui-preferences-file", "Path to JSON file storing UI preferences").StringVar(&c.uiPreferencesFile) cmd.Flag("log-server-requests", "Log server requests").Hidden().BoolVar(&c.logServerRequests) + cmd.Flag("disable-csrf-token-checks", "Disable CSRF token").Hidden().BoolVar(&c.disableCSRFTokenChecks) c.sf.setup(cmd) c.co.setup(cmd) @@ -116,7 +122,6 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent })) } -// nolint:funlen func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error { authn, err := c.getAuthenticator(ctx) if err != nil { @@ -142,6 +147,8 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error PasswordPersist: c.svc.passwordPersistenceStrategy(), UIPreferencesFile: uiPreferencesFile, UITitlePrefix: c.uiTitlePrefix, + + DisableCSRFTokenChecks: c.disableCSRFTokenChecks, }) if err != nil { return errors.Wrap(err, "unable to initialize server") @@ -155,16 +162,9 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error return errors.Wrap(err, "error connecting to repository") } - mux := http.NewServeMux() + m := mux.NewRouter() - mux.Handle("/api/", srv.APIHandlers(c.serverStartLegacyRepositoryAPI)) - - if c.serverStartHTMLPath != "" { - fileServer := srv.ServeStaticFiles(http.Dir(c.serverStartHTMLPath)) - mux.Handle("/", fileServer) - } else if c.serverStartUI { - mux.Handle("/", srv.ServeStaticFiles(server.AssetFile())) - } + c.setupHandlers(srv, m) httpServer := &http.Server{ Addr: stripProtocol(c.sf.serverAddress), @@ -185,11 +185,11 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error // init prometheus after adding interceptors that require credentials, so that this // handler can be called without auth - if err = initPrometheus(mux); err != nil { + if err = initPrometheus(m); err != nil { return errors.Wrap(err, "error initializing Prometheus") } - var handler http.Handler = mux + var handler http.Handler = m if c.serverStartGRPC { handler = srv.GRPCRouterHandler(handler) @@ -222,7 +222,27 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error return errors.Wrap(srv.SetRepository(ctx, nil), "error setting active repository") } -func initPrometheus(mux *http.ServeMux) error { +func (c *commandServerStart) setupHandlers(srv *server.Server, m *mux.Router) { + if c.serverStartLegacyRepositoryAPI { + srv.SetupRepositoryAPIHandlers(m) + } + + if c.serverStartControlAPI { + srv.SetupControlAPIHandlers(m) + } + + if c.serverStartUI { + srv.SetupHTMLUIAPIHandlers(m) + + if c.serverStartHTMLPath != "" { + srv.ServeStaticFiles(m, http.Dir(c.serverStartHTMLPath)) + } else { + srv.ServeStaticFiles(m, server.AssetFile()) + } + } +} + +func initPrometheus(m *mux.Router) error { reg := prom.NewRegistry() if err := reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})); err != nil { return errors.Wrap(err, "error registering process collector") @@ -239,7 +259,7 @@ func initPrometheus(mux *http.ServeMux) error { return errors.Wrap(err, "unable to initialize prometheus exporter") } - mux.Handle("/metrics", pe) + m.Handle("/metrics", pe) return nil } diff --git a/internal/apiclient/apiclient.go b/internal/apiclient/apiclient.go index 1d0c30a3f..3c9905293 100644 --- a/internal/apiclient/apiclient.go +++ b/internal/apiclient/apiclient.go @@ -8,6 +8,8 @@ "io" "net/http" "net/http/cookiejar" + "regexp" + "strings" "github.com/pkg/errors" @@ -18,34 +20,69 @@ var log = logging.Module("client") +// CSRFTokenHeader is the name of CSRF token header that must be sent for most API calls. +// nolint:gosec +const CSRFTokenHeader = "X-Kopia-Csrf-Token" + // KopiaAPIClient provides helper methods for communicating with Kopia API server. type KopiaAPIClient struct { BaseURL string HTTPClient *http.Client + + CSRFToken string } // Get is a helper that performs HTTP GET on a URL with the specified suffix and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Get(ctx context.Context, urlSuffix string, onNotFound error, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodGet, c.BaseURL+urlSuffix, onNotFound, nil, respPayload) + return c.runRequest(ctx, http.MethodGet, c.actualURL(urlSuffix), onNotFound, nil, respPayload) } // Post is a helper that performs HTTP POST on a URL with the specified body from reqPayload and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Post(ctx context.Context, urlSuffix string, reqPayload, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodPost, c.BaseURL+urlSuffix, nil, reqPayload, respPayload) + return c.runRequest(ctx, http.MethodPost, c.actualURL(urlSuffix), nil, reqPayload, respPayload) } // Put is a helper that performs HTTP PUT on a URL with the specified body from reqPayload and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Put(ctx context.Context, urlSuffix string, reqPayload, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodPut, c.BaseURL+urlSuffix, nil, reqPayload, respPayload) + return c.runRequest(ctx, http.MethodPut, c.actualURL(urlSuffix), nil, reqPayload, respPayload) } // Delete is a helper that performs HTTP DELETE on a URL with the specified body from reqPayload and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Delete(ctx context.Context, urlSuffix string, onNotFound error, reqPayload, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodDelete, c.BaseURL+urlSuffix, onNotFound, reqPayload, respPayload) + return c.runRequest(ctx, http.MethodDelete, c.actualURL(urlSuffix), onNotFound, reqPayload, respPayload) +} + +// FetchCSRFTokenForTesting fetches the CSRF token and session cookie for use when making subsequent calls to the API. +// This simulates the browser behavior of downloading the "/" and is required to call the UI-only methods. +func (c *KopiaAPIClient) FetchCSRFTokenForTesting(ctx context.Context) error { + var b []byte + + if err := c.Get(ctx, "/", nil, &b); err != nil { + return err + } + + re := regexp.MustCompile(``) + + match := re.FindSubmatch(b) + if match == nil { + return errors.Errorf("CSRF token not found") + } + + c.CSRFToken = string(match[1]) + + return nil +} + +func (c *KopiaAPIClient) actualURL(suffix string) string { + if strings.HasPrefix(suffix, "/") { + return c.BaseURL + suffix + } + + return c.BaseURL + "/api/v1/" + suffix } func (c *KopiaAPIClient) runRequest(ctx context.Context, method, url string, notFoundError error, reqPayload, respPayload interface{}) error { @@ -59,6 +96,10 @@ func (c *KopiaAPIClient) runRequest(ctx context.Context, method, url string, not return errors.Wrap(err, "error creating request") } + if c.CSRFToken != "" { + req.Header.Add(CSRFTokenHeader, c.CSRFToken) + } + if contentType != "" { req.Header.Set("Content-Type", contentType) } @@ -165,11 +206,12 @@ func NewKopiaAPIClient(options Options) (*KopiaAPIClient, error) { } return &KopiaAPIClient{ - options.BaseURL + "/api/v1/", + options.BaseURL, &http.Client{ Jar: cj, Transport: transport, }, + "", }, nil } diff --git a/internal/server/api_cli_test.go b/internal/server/api_cli_test.go index 27c00eca8..ddad1f27e 100644 --- a/internal/server/api_cli_test.go +++ b/internal/server/api_cli_test.go @@ -23,6 +23,7 @@ func TestCLIAPI(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) resp := &serverapi.CLIInfo{} require.NoError(t, cli.Get(ctx, "cli", nil, resp)) diff --git a/internal/server/api_paths_test.go b/internal/server/api_paths_test.go index 66df7789a..ff64e99fc 100644 --- a/internal/server/api_paths_test.go +++ b/internal/server/api_paths_test.go @@ -23,6 +23,7 @@ func TestPathsAPI(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir0 := testutil.TempDirectory(t) diff --git a/internal/server/api_policies_test.go b/internal/server/api_policies_test.go index d745a3bcd..285808a0c 100644 --- a/internal/server/api_policies_test.go +++ b/internal/server/api_policies_test.go @@ -29,6 +29,7 @@ func TestPolicies(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir0 := testutil.TempDirectory(t) si0 := localSource(env, dir0) diff --git a/internal/server/api_sources_test.go b/internal/server/api_sources_test.go index 64022ec97..f9c124982 100644 --- a/internal/server/api_sources_test.go +++ b/internal/server/api_sources_test.go @@ -30,6 +30,7 @@ func TestSnapshotCounters(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir := testutil.TempDirectory(t) si := localSource(env, dir) @@ -103,6 +104,7 @@ func TestSourceRefreshesAfterPolicy(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir := testutil.TempDirectory(t) si := localSource(env, dir) diff --git a/internal/server/api_ui_pref_test.go b/internal/server/api_ui_pref_test.go index 8c8890519..21262a4ea 100644 --- a/internal/server/api_ui_pref_test.go +++ b/internal/server/api_ui_pref_test.go @@ -23,6 +23,8 @@ func TestUIPreferences(t *testing.T) { require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) + var p, p2 serverapi.UIPreferences require.NoError(t, cli.Get(ctx, "ui-preferences", nil, &p)) diff --git a/internal/server/server.go b/internal/server/server.go index c0a096fb4..0ff3be9d4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -45,6 +45,13 @@ kopiaAuthCookieIssuer = "kopia-server" ) +type csrfTokenOption int + +const ( + csrfTokenRequired csrfTokenOption = 1 + iota + csrfTokenNotRequired +) + type apiRequestFunc func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) // Server exposes simple HTTP API for programmatically accessing Kopia features. @@ -72,92 +79,81 @@ type Server struct { grpcServerState } -// APIHandlers handles API requests. -func (s *Server) APIHandlers(legacyAPI bool) http.Handler { - m := mux.NewRouter() - +// SetupHTMLUIAPIHandlers registers API requests required by the HTMLUI. +func (s *Server) SetupHTMLUIAPIHandlers(m *mux.Router) { // sources - m.HandleFunc("/api/v1/sources", s.handleAPI(requireUIUser, s.handleSourcesList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/sources", s.handleAPI(requireUIUser, s.handleSourcesCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/sources/upload", s.handleAPI(requireUIUser, s.handleUpload)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/sources/cancel", s.handleAPI(requireUIUser, s.handleCancel)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/sources", s.handleUI(s.handleSourcesList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/sources", s.handleUI(s.handleSourcesCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/sources/upload", s.handleUI(s.handleUpload)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/sources/cancel", s.handleUI(s.handleCancel)).Methods(http.MethodPost) // snapshots - m.HandleFunc("/api/v1/snapshots", s.handleAPI(requireUIUser, s.handleSnapshotList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyPut)).Methods(http.MethodPut) - m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyDelete)).Methods(http.MethodDelete) - m.HandleFunc("/api/v1/policy/resolve", s.handleAPI(requireUIUser, s.handlePolicyResolve)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/snapshots", s.handleUI(s.handleSnapshotList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyPut)).Methods(http.MethodPut) + m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyDelete)).Methods(http.MethodDelete) + m.HandleFunc("/api/v1/policy/resolve", s.handleUI(s.handlePolicyResolve)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/policies", s.handleUI(s.handlePolicyList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/refresh", s.handleUI(s.handleRefresh)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/objects/{objectID}", s.requireAuth(csrfTokenNotRequired, s.handleObjectGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/restore", s.handleUI(s.handleRestore)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/estimate", s.handleUI(s.handleEstimate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/paths/resolve", s.handleUI(s.handlePathResolve)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/cli", s.handleUI(s.handleCLIInfo)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/status", s.handleUI(s.handleRepoStatus)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/sync", s.handleUI(s.handleRepoSync)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/connect", s.handleUIPossiblyNotConnected(s.handleRepoConnect)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/exists", s.handleUIPossiblyNotConnected(s.handleRepoExists)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/create", s.handleUIPossiblyNotConnected(s.handleRepoCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/description", s.handleUI(s.handleRepoSetDescription)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/disconnect", s.handleUI(s.handleRepoDisconnect)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/algorithms", s.handleUIPossiblyNotConnected(s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/throttle", s.handleUI(s.handleRepoGetThrottle)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/throttle", s.handleUI(s.handleRepoSetThrottle)).Methods(http.MethodPut) - m.HandleFunc("/api/v1/policies", s.handleAPI(requireUIUser, s.handlePolicyList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/mounts", s.handleUI(s.handleMountCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleUI(s.handleMountDelete)).Methods(http.MethodDelete) + m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleUI(s.handleMountGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/mounts", s.handleUI(s.handleMountList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/refresh", s.handleAPI(anyAuthenticatedUser, s.handleRefresh)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/current-user", s.handleUIPossiblyNotConnected(s.handleCurrentUser)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/ui-preferences", s.handleUIPossiblyNotConnected(s.handleGetUIPreferences)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/ui-preferences", s.handleUIPossiblyNotConnected(s.handleSetUIPreferences)).Methods(http.MethodPut) - m.HandleFunc("/api/v1/objects/{objectID}", s.requireAuth(s.handleObjectGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/restore", s.handleAPI(requireUIUser, s.handleRestore)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/estimate", s.handleAPI(requireUIUser, s.handleEstimate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/tasks-summary", s.handleUI(s.handleTaskSummary)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks", s.handleUI(s.handleTaskList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks/{taskID}", s.handleUI(s.handleTaskInfo)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks/{taskID}/logs", s.handleUI(s.handleTaskLogs)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks/{taskID}/cancel", s.handleUI(s.handleTaskCancel)).Methods(http.MethodPost) +} - // path APIs - m.HandleFunc("/api/v1/paths/resolve", s.handleAPI(requireUIUser, s.handlePathResolve)).Methods(http.MethodPost) +// SetupRepositoryAPIHandlers registers HTTP repository API handlers. +func (s *Server) SetupRepositoryAPIHandlers(m *mux.Router) { + m.HandleFunc("/api/v1/flush", s.handleRepositoryAPI(anyAuthenticatedUser, s.handleFlush)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/parameters", s.handleRepositoryAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet) - // path APIs - m.HandleFunc("/api/v1/cli", s.handleAPI(requireUIUser, s.handleCLIInfo)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentInfo)).Methods(http.MethodGet).Queries("info", "1") + m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelAppend), s.handleContentPut)).Methods(http.MethodPut) - // methods that can be called by any authenticated user (UI or remote user). - m.HandleFunc("/api/v1/flush", s.handleAPI(anyAuthenticatedUser, s.handleFlush)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/status", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoStatus)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/repo/sync", s.handleAPI(anyAuthenticatedUser, s.handleRepoSync)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestDelete)).Methods(http.MethodDelete) + m.HandleFunc("/api/v1/manifests", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/manifests", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestList)).Methods(http.MethodGet) +} - m.HandleFunc("/api/v1/repo/connect", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoConnect)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/exists", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoExists)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/create", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/description", s.handleAPI(requireUIUser, s.handleRepoSetDescription)).Methods(http.MethodPost) - - m.HandleFunc("/api/v1/repo/disconnect", s.handleAPI(requireUIUser, s.handleRepoDisconnect)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/algorithms", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoGetThrottle)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoSetThrottle)).Methods(http.MethodPut) - - if legacyAPI { - m.HandleFunc("/api/v1/repo/parameters", s.handleAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet) - - m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentInfo)).Methods(http.MethodGet).Queries("info", "1") - m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelAppend), s.handleContentPut)).Methods(http.MethodPut) - - m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestDelete)).Methods(http.MethodDelete) - m.HandleFunc("/api/v1/manifests", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/manifests", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestList)).Methods(http.MethodGet) - } - - m.HandleFunc("/api/v1/mounts", s.handleAPI(requireUIUser, s.handleMountCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleAPI(requireUIUser, s.handleMountDelete)).Methods(http.MethodDelete) - m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleAPI(requireUIUser, s.handleMountGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/mounts", s.handleAPI(requireUIUser, s.handleMountList)).Methods(http.MethodGet) - - m.HandleFunc("/api/v1/current-user", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleCurrentUser)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/ui-preferences", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleGetUIPreferences)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/ui-preferences", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleSetUIPreferences)).Methods(http.MethodPut) - - m.HandleFunc("/api/v1/tasks-summary", s.handleAPI(requireUIUser, s.handleTaskSummary)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks", s.handleAPI(requireUIUser, s.handleTaskList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks/{taskID}", s.handleAPI(requireUIUser, s.handleTaskInfo)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks/{taskID}/logs", s.handleAPI(requireUIUser, s.handleTaskLogs)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks/{taskID}/cancel", s.handleAPI(requireUIUser, s.handleTaskCancel)).Methods(http.MethodPost) - - // server control API, requires authentication as `server-control`. - m.HandleFunc("/api/v1/control/status", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleRepoStatus)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/control/sources", s.handleAPI(requireServerControlUser, s.handleSourcesList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/control/flush", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleFlush)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/refresh", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleRefresh)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/shutdown", s.handleAPI(requireServerControlUser, s.handleShutdown)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/trigger-snapshot", s.handleAPI(requireServerControlUser, s.handleUpload)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/cancel-snapshot", s.handleAPI(requireServerControlUser, s.handleCancel)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/pause-source", s.handleAPI(requireServerControlUser, s.handlePause)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/resume-source", s.handleAPI(requireServerControlUser, s.handleResume)).Methods(http.MethodPost) - - return m +// SetupControlAPIHandlers registers control API handlers. +func (s *Server) SetupControlAPIHandlers(m *mux.Router) { + // server control API, requires authentication as `server-control` and no CSRF token. + m.HandleFunc("/api/v1/control/sources", s.handleServerControlAPI(s.handleSourcesList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/control/status", s.handleServerControlAPIPossiblyNotConnected(s.handleRepoStatus)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/control/flush", s.handleServerControlAPI(s.handleFlush)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/refresh", s.handleServerControlAPI(s.handleRefresh)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/shutdown", s.handleServerControlAPIPossiblyNotConnected(s.handleShutdown)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/trigger-snapshot", s.handleServerControlAPI(s.handleUpload)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/cancel-snapshot", s.handleServerControlAPI(s.handleCancel)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/pause-source", s.handleServerControlAPI(s.handlePause)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/resume-source", s.handleServerControlAPI(s.handleResume)).Methods(http.MethodPost) } func (s *Server) isAuthenticated(w http.ResponseWriter, r *http.Request) bool { @@ -198,6 +194,7 @@ func (s *Server) isAuthenticated(w http.ResponseWriter, r *http.Request) bool { Name: kopiaAuthCookie, Value: ac, Expires: now.Add(kopiaAuthCookieTTL), + Path: "/", }) } @@ -233,12 +230,19 @@ func (s *Server) generateShortTermAuthCookie(username string, now time.Time) (st }).SignedString(s.authCookieSigningKey) } -func (s *Server) requireAuth(f http.HandlerFunc) http.HandlerFunc { +func (s *Server) requireAuth(checkCSRFToken csrfTokenOption, f http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !s.isAuthenticated(w, r) { return } + if checkCSRFToken == csrfTokenRequired { + if !s.validateCSRFToken(r) { + http.Error(w, "Invalid or missing CSRF token.\n", http.StatusUnauthorized) + return + } + } + f(w, r) } } @@ -257,8 +261,8 @@ func (s *Server) httpAuthorizationInfo(ctx context.Context, r *http.Request) aut type isAuthorizedFunc func(s *Server, r *http.Request) bool -func (s *Server) handleAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc { - return s.handleAPIPossiblyNotConnected(isAuthorized, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { +func (s *Server) handleServerControlAPI(f apiRequestFunc) http.HandlerFunc { + return s.handleServerControlAPIPossiblyNotConnected(func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { if s.rep == nil { return nil, requestError(serverapi.ErrorNotConnected, "not connected") } @@ -267,8 +271,38 @@ func (s *Server) handleAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http }) } -func (s *Server) handleAPIPossiblyNotConnected(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc { - return s.requireAuth(func(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleServerControlAPIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(requireServerControlUser, csrfTokenNotRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + return f(ctx, r, body) + }) +} + +func (s *Server) handleRepositoryAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(isAuthorized, csrfTokenNotRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + if s.rep == nil { + return nil, requestError(serverapi.ErrorNotConnected, "not connected") + } + + return f(ctx, r, body) + }) +} + +func (s *Server) handleUI(f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(requireUIUser, csrfTokenRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + if s.rep == nil { + return nil, requestError(serverapi.ErrorNotConnected, "not connected") + } + + return f(ctx, r, body) + }) +} + +func (s *Server) handleUIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(requireUIUser, csrfTokenRequired, f) +} + +func (s *Server) handleRequestPossiblyNotConnected(isAuthorized isAuthorizedFunc, checkCSRFToken csrfTokenOption, f apiRequestFunc) http.HandlerFunc { + return s.requireAuth(checkCSRFToken, func(w http.ResponseWriter, r *http.Request) { // we must pre-read request body before acquiring the lock as it sometimes leads to deadlock // in HTTP/2 server. // See https://github.com/golang/go/issues/40816 @@ -682,11 +716,18 @@ func (s *Server) isKnownUIRoute(path string) bool { strings.HasPrefix(path, "/repo") } -func (s *Server) patchIndexBytes(b []byte) []byte { +func (s *Server) patchIndexBytes(sessionID string, b []byte) []byte { if s.options.UITitlePrefix != "" { b = bytes.ReplaceAll(b, []byte(""), []byte("<title>"+html.EscapeString(s.options.UITitlePrefix))) } + csrfToken := s.generateCSRFToken(sessionID) + + // insert <meta name="kopia-csrf-token" content="..." /> just before closing head tag. + b = bytes.ReplaceAll(b, + []byte(`</head>`), + []byte(`<meta name="kopia-csrf-token" content="`+csrfToken+`" /></head>`)) + return b } @@ -706,14 +747,14 @@ func maybeReadIndexBytes(fs http.FileSystem) []byte { return rd } -// ServeStaticFiles returns HTTP handler that serves static files and dynamically patches index.html to embed CSRF token, etc. -func (s *Server) ServeStaticFiles(fs http.FileSystem) http.Handler { +// ServeStaticFiles configures HTTP handler that serves static files and dynamically patches index.html to embed CSRF token, etc. +func (s *Server) ServeStaticFiles(m *mux.Router, fs http.FileSystem) { h := http.FileServer(fs) // read bytes from 'index.html'. indexBytes := maybeReadIndexBytes(fs) - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if s.isKnownUIRoute(r.URL.Path) { r2 := new(http.Request) *r2 = *r @@ -733,7 +774,22 @@ func (s *Server) ServeStaticFiles(fs http.FileSystem) http.Handler { } if r.URL.Path == "/" && indexBytes != nil { - http.ServeContent(w, r, "/", clock.Now(), bytes.NewReader(s.patchIndexBytes(indexBytes))) + var sessionID string + + if cookie, err := r.Cookie(kopiaSessionCookie); err == nil { + // already in a session, likely a new tab was opened + sessionID = cookie.Value + } else { + sessionID = uuid.NewString() + + http.SetCookie(w, &http.Cookie{ + Name: kopiaSessionCookie, + Value: sessionID, + Path: "/", + }) + } + + http.ServeContent(w, r, "/", clock.Now(), bytes.NewReader(s.patchIndexBytes(sessionID, indexBytes))) return } @@ -752,7 +808,7 @@ type Options struct { PasswordPersist passwordpersist.Strategy AuthCookieSigningKey string LogRequests bool - UIUser string // name of the user allowed to access the UI + UIUser string // name of the user allowed to access the UI API UIPreferencesFile string // name of the JSON file storing UI preferences ServerControlUser string // name of the user allowed to access the server control API DisableCSRFTokenChecks bool diff --git a/internal/server/server_authz_checks.go b/internal/server/server_authz_checks.go index 7d8ea73ea..60415fccd 100644 --- a/internal/server/server_authz_checks.go +++ b/internal/server/server_authz_checks.go @@ -2,16 +2,72 @@ import ( "context" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "io" "net/http" + "github.com/kopia/kopia/internal/apiclient" "github.com/kopia/kopia/internal/auth" ) +// kopiaSessionCookie is the name of the session cookie that Kopia server will generate for all +// UI sessions. +const kopiaSessionCookie = "Kopia-Session-Cookie" + +func (s *Server) generateCSRFToken(sessionID string) string { + h := hmac.New(sha256.New, s.authCookieSigningKey) + + if _, err := io.WriteString(h, sessionID); err != nil { + panic("io.WriteString() failed: " + err.Error()) + } + + return hex.EncodeToString(h.Sum(nil)) +} + +func (s *Server) validateCSRFToken(r *http.Request) bool { + if s.options.DisableCSRFTokenChecks { + return true + } + + ctx := r.Context() + path := r.URL.Path + + sessionCookie, err := r.Cookie(kopiaSessionCookie) + if err != nil { + log(ctx).Warnf("missing or invalid session cookie for %q: %v", path, err) + + return false + } + + validToken := s.generateCSRFToken(sessionCookie.Value) + + token := r.Header.Get(apiclient.CSRFTokenHeader) + if token == "" { + log(ctx).Warnf("missing CSRF token for %v", path) + return false + } + + if subtle.ConstantTimeCompare([]byte(validToken), []byte(token)) == 1 { + return true + } + + log(ctx).Warnf("got invalid CSRF token for %v: %v, want %v, session %v", path, token, validToken, sessionCookie.Value) + + return false +} + func requireUIUser(s *Server, r *http.Request) bool { if s.authenticator == nil { return true } + if s.options.UIUser == "" { + return false + } + user, _, _ := r.BasicAuth() return user == s.options.UIUser @@ -22,6 +78,10 @@ func requireServerControlUser(s *Server, r *http.Request) bool { return true } + if s.options.ServerControlUser == "" { + return false + } + user, _, _ := r.BasicAuth() return user == s.options.ServerControlUser diff --git a/internal/server/server_authz_checks_test.go b/internal/server/server_authz_checks_test.go new file mode 100644 index 000000000..cde09b493 --- /dev/null +++ b/internal/server/server_authz_checks_test.go @@ -0,0 +1,101 @@ +package server + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/kopia/kopia/internal/apiclient" +) + +func TestGenerateCSRFToken(t *testing.T) { + s1 := &Server{ + authCookieSigningKey: []byte("some-key"), + } + + s2 := &Server{ + authCookieSigningKey: []byte("some-other-key"), + } + + cases := []struct { + srv *Server + session string + wantToken string + }{ + // CSRF token is a stable function of session ID and per-server so we can hardcode it + {s1, "session1", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0"}, + {s2, "session1", "7fd10608493e844581247d44e61de56b80df83ecdae49891f150823f73524ef7"}, + {s1, "session2", "e3aeba64243485ac4664e27445f48711b987c3bf5c7e58d1b89eb1e2722fedcd"}, + {s2, "session2", "714a124df0f6b6e79500fb06a900e7870a94f50b6d1e532c92a0abc0c63146f8"}, + } + + for _, tc := range cases { + require.Equal(t, tc.wantToken, tc.srv.generateCSRFToken(tc.session)) + } +} + +func TestValidateCSRFToken(t *testing.T) { + s1 := &Server{ + authCookieSigningKey: []byte("some-key"), + } + + s2 := &Server{ + authCookieSigningKey: []byte("some-other-key"), + } + + s3 := &Server{ + options: Options{ + DisableCSRFTokenChecks: true, + }, + } + + cases := []struct { + srv *Server + session string + token string + valid bool + }{ + // valid + {s1, "session1", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0", true}, + {s2, "session1", "7fd10608493e844581247d44e61de56b80df83ecdae49891f150823f73524ef7", true}, + {s1, "session2", "e3aeba64243485ac4664e27445f48711b987c3bf5c7e58d1b89eb1e2722fedcd", true}, + {s2, "session2", "714a124df0f6b6e79500fb06a900e7870a94f50b6d1e532c92a0abc0c63146f8", true}, + + // invalid cases + {s1, "", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0", false}, // missing session ID + {s1, "session2", "", false}, // missing token + {s2, "session2", "invalid", false}, // invalid token + + // token is invalid but ignored, since 's3' does not validate tokens. + {s3, "", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0", true}, + {s3, "session2", "", true}, + {s3, "session2", "invalid-token", true}, + } + + ctx := context.Background() + + for i, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, "GET", "/somepath", http.NoBody) + require.NoError(t, err) + + if tc.session != "" { + req.AddCookie(&http.Cookie{ + Name: "Kopia-Session-Cookie", + Value: tc.session, + }) + } + + if tc.token != "" { + req.Header.Add(apiclient.CSRFTokenHeader, tc.token) + } + + require.Equal(t, tc.valid, tc.srv.validateCSRFToken(req)) + }) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index a774ec62e..df6a81943 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -12,6 +12,7 @@ "time" "github.com/google/go-cmp/cmp" + "github.com/gorilla/mux" "github.com/pkg/errors" "github.com/stretchr/testify/require" @@ -71,7 +72,13 @@ func startServer(t *testing.T, env *repotesting.Environment, tls bool) *repo.API asi := &repo.APIServerInfo{} - hs := httptest.NewUnstartedServer(s.GRPCRouterHandler(s.APIHandlers(true))) + m := mux.NewRouter() + s.SetupHTMLUIAPIHandlers(m) + s.SetupRepositoryAPIHandlers(m) + s.SetupControlAPIHandlers(m) + s.ServeStaticFiles(m, server.AssetFile()) + + hs := httptest.NewUnstartedServer(s.GRPCRouterHandler(m)) if tls { hs.EnableHTTP2 = true hs.StartTLS() @@ -131,6 +138,7 @@ func TestGPRServer_AuthenticationError(t *testing.T) { } } +// nolint:gocyclo func TestServerUIAccessDeniedToRemoteUser(t *testing.T) { ctx, env := repotesting.NewEnvironment(t, repotesting.FormatNotImportant) si := startServer(t, env, true) @@ -141,6 +149,14 @@ func TestServerUIAccessDeniedToRemoteUser(t *testing.T) { Username: testUsername + "@" + testHostname, Password: testPassword, }) + require.NoError(t, err) + + uiUserWithoutCSRFToken, err := apiclient.NewKopiaAPIClient(apiclient.Options{ + BaseURL: si.BaseURL, + TrustedServerCertificateFingerprint: si.TrustedServerCertificateFingerprint, + Username: testUIUsername, + Password: testUIPassword, + }) require.NoError(t, err) @@ -153,7 +169,10 @@ func TestServerUIAccessDeniedToRemoteUser(t *testing.T) { require.NoError(t, err) - // examples of URLs and expected statuses returned when UI user calls them, but which must return 403 when + require.NoError(t, uiUserClient.FetchCSRFTokenForTesting(ctx)) + // do not call uiUserWithoutCSRFToken.FetchCSRFTokenForTesting() + + // examples of URLs and expected statuses returned when UI user calls them, but which must return 401 due to missing CSRF token // remote user calls them. getUrls := map[string]int{ "mounts": http.StatusOK, @@ -171,8 +190,34 @@ func TestServerUIAccessDeniedToRemoteUser(t *testing.T) { t.Run(urlSuffix, func(t *testing.T) { var hsr apiclient.HTTPStatusError - if err := remoteUserClient.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || hsr.HTTPStatusCode != http.StatusForbidden { - t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, http.StatusForbidden) + wantFailure := http.StatusUnauthorized // 401 + + if urlSuffix == "objects/abcd" { + // this is a special one that does not require CSRF token but will still fail with 403 + wantFailure = http.StatusForbidden + } + + if err := remoteUserClient.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || (hsr.HTTPStatusCode != wantFailure) { + t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, wantFailure) + } + + if wantStatus == http.StatusOK { + if err := uiUserClient.Get(ctx, urlSuffix, nil, nil); err != nil { + t.Fatalf("expected success, got %v", err) + } + } else if err := uiUserClient.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || hsr.HTTPStatusCode != wantStatus { + t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, wantStatus) + } + + // objects/abcd does not require CSRF token so will fail with 404 instead of 403. + // This is fine since this is a side-effect-free GET method so same-origin policy + // will protect access to data. + if urlSuffix == "objects/abcd" { + wantFailure = http.StatusNotFound + } + + if err := uiUserWithoutCSRFToken.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || (hsr.HTTPStatusCode != wantFailure) { + t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, wantFailure) } if wantStatus == http.StatusOK { @@ -213,7 +258,6 @@ func remoteRepositoryTest(ctx context.Context, t *testing.T, rep repo.Repository }, func(ctx context.Context, w repo.RepositoryWriter) error { mustGetObjectNotFound(ctx, t, w, "abcd") mustGetManifestNotFound(ctx, t, w, "mnosuchmanifest") - mustManifestNotFound(t, w.DeleteManifest(ctx, manifestID2)) mustListSnapshotCount(ctx, t, w, 0) result = mustWriteObject(ctx, t, w, written) diff --git a/tests/end_to_end_test/server_start_test.go b/tests/end_to_end_test/server_start_test.go index e0595ccd6..46ba6ba08 100644 --- a/tests/end_to_end_test/server_start_test.go +++ b/tests/end_to_end_test/server_start_test.go @@ -98,6 +98,7 @@ func TestServerStart(t *testing.T) { LogRequests: true, }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) controlClient, err := apiclient.NewKopiaAPIClient(apiclient.Options{ BaseURL: sp.baseURL, @@ -113,8 +114,7 @@ func TestServerStart(t *testing.T) { waitUntilServerStarted(ctx, t, controlClient) verifyUIServedWithCorrectTitle(t, cli, sp) - st := verifyServerConnected(t, controlClient, true) - require.Equal(t, "filesystem", st.Storage) + verifyServerConnected(t, controlClient, true) limits, err := serverapi.GetThrottlingLimits(ctx, cli) require.NoError(t, err) @@ -234,6 +234,7 @@ func TestServerCreateAndConnectViaAPI(t *testing.T) { TrustedServerCertificateFingerprint: sp.sha256Fingerprint, }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) controlClient, err := apiclient.NewKopiaAPIClient(apiclient.Options{ BaseURL: sp.baseURL, @@ -306,20 +307,11 @@ func TestConnectToExistingRepositoryViaAPI(t *testing.T) { "--override-hostname=fake-hostname", "--override-username=fake-username") t.Logf("detected server parameters %#v", sp) - cli, err := apiclient.NewKopiaAPIClient(apiclient.Options{ - BaseURL: sp.baseURL, - Username: "kopia", - Password: sp.password, - TrustedServerCertificateFingerprint: sp.sha256Fingerprint, - }) - require.NoError(t, err) - controlClient, err := apiclient.NewKopiaAPIClient(apiclient.Options{ BaseURL: sp.baseURL, Username: "server-control", Password: sp.serverControlPassword, TrustedServerCertificateFingerprint: sp.sha256Fingerprint, - LogRequests: true, }) require.NoError(t, err) @@ -328,6 +320,15 @@ func TestConnectToExistingRepositoryViaAPI(t *testing.T) { waitUntilServerStarted(ctx, t, controlClient) verifyServerConnected(t, controlClient, false) + cli, err := apiclient.NewKopiaAPIClient(apiclient.Options{ + BaseURL: sp.baseURL, + Username: "kopia", + Password: sp.password, + TrustedServerCertificateFingerprint: sp.sha256Fingerprint, + }) + require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) + if err = serverapi.ConnectToRepository(ctx, cli, &serverapi.ConnectRepositoryRequest{ Password: testenv.TestRepoPassword, Storage: connInfo, @@ -533,10 +534,21 @@ func verifyUIServedWithCorrectTitle(t *testing.T, cli *apiclient.KopiaAPIClient, func waitUntilServerStarted(ctx context.Context, t *testing.T, cli *apiclient.KopiaAPIClient) { t.Helper() - if err := retry.PeriodicallyNoValue(ctx, 1*time.Second, 180, "wait for server start", func() error { + require.NoError(t, retry.PeriodicallyNoValue(ctx, 1*time.Second, 180, "wait for server start", func() error { _, err := serverapi.Status(testlogging.Context(t), cli) return err - }, retry.Always); err != nil { - t.Fatalf("server failed to start") - } + }, func(err error) bool { + var hs apiclient.HTTPStatusError + + if errors.As(err, &hs) { + switch hs.HTTPStatusCode { + case http.StatusBadRequest: + return false + case http.StatusForbidden: + return false + } + } + + return true + })) }