diff --git a/cli/app.go b/cli/app.go index 72e29d5fb..b8a1b411b 100644 --- a/cli/app.go +++ b/cli/app.go @@ -12,6 +12,7 @@ "github.com/alecthomas/kingpin" "github.com/fatih/color" + "github.com/gorilla/mux" "github.com/mattn/go-colorable" "github.com/pkg/errors" @@ -425,21 +426,21 @@ func (c *App) maybeRepositoryAction(act func(ctx context.Context, rep repo.Repos defer gather.DumpStats(ctx) if c.metricsListenAddr != "" { - mux := http.NewServeMux() - if err := initPrometheus(mux); err != nil { + m := mux.NewRouter() + if err := initPrometheus(m); err != nil { return errors.Wrap(err, "unable to initialize prometheus.") } if c.enablePProf { - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + m.HandleFunc("/debug/pprof/", pprof.Index) + m.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + m.HandleFunc("/debug/pprof/profile", pprof.Profile) + m.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + m.HandleFunc("/debug/pprof/trace", pprof.Trace) } log(ctx).Infof("starting prometheus metrics on %v", c.metricsListenAddr) - go http.ListenAndServe(c.metricsListenAddr, mux) // nolint:errcheck + go http.ListenAndServe(c.metricsListenAddr, m) // nolint:errcheck } memtrack.Dump(ctx, "before openRepository") diff --git a/cli/command_server_start.go b/cli/command_server_start.go index b71a2b79e..3670c00d6 100644 --- a/cli/command_server_start.go +++ b/cli/command_server_start.go @@ -14,6 +14,7 @@ "time" "contrib.go.opencensus.io/exporter/prometheus" + "github.com/gorilla/mux" "github.com/pkg/errors" prom "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" @@ -30,10 +31,11 @@ type commandServerStart struct { co connectOptions serverStartHTMLPath string - serverStartUI bool + serverStartUI bool serverStartLegacyRepositoryAPI bool serverStartGRPC bool + serverStartControlAPI bool serverStartRefreshInterval time.Duration serverStartInsecure bool @@ -63,6 +65,8 @@ type commandServerStart struct { logServerRequests bool + disableCSRFTokenChecks bool // disable CSRF token checks - used for development/debugging only + sf serverFlags svc advancedAppServices out textOutput @@ -75,6 +79,7 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent cmd.Flag("legacy-api", "Start the legacy server API").Default("true").BoolVar(&c.serverStartLegacyRepositoryAPI) cmd.Flag("grpc", "Start the GRPC server").Default("true").BoolVar(&c.serverStartGRPC) + cmd.Flag("control-api", "Start the control API").Default("true").BoolVar(&c.serverStartControlAPI) cmd.Flag("refresh-interval", "Frequency for refreshing repository status").Default("300s").DurationVar(&c.serverStartRefreshInterval) cmd.Flag("insecure", "Allow insecure configurations (do not use in production)").Hidden().BoolVar(&c.serverStartInsecure) @@ -104,6 +109,7 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent cmd.Flag("ui-preferences-file", "Path to JSON file storing UI preferences").StringVar(&c.uiPreferencesFile) cmd.Flag("log-server-requests", "Log server requests").Hidden().BoolVar(&c.logServerRequests) + cmd.Flag("disable-csrf-token-checks", "Disable CSRF token").Hidden().BoolVar(&c.disableCSRFTokenChecks) c.sf.setup(cmd) c.co.setup(cmd) @@ -116,7 +122,6 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent })) } -// nolint:funlen func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error { authn, err := c.getAuthenticator(ctx) if err != nil { @@ -142,6 +147,8 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error PasswordPersist: c.svc.passwordPersistenceStrategy(), UIPreferencesFile: uiPreferencesFile, UITitlePrefix: c.uiTitlePrefix, + + DisableCSRFTokenChecks: c.disableCSRFTokenChecks, }) if err != nil { return errors.Wrap(err, "unable to initialize server") @@ -155,16 +162,9 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error return errors.Wrap(err, "error connecting to repository") } - mux := http.NewServeMux() + m := mux.NewRouter() - mux.Handle("/api/", srv.APIHandlers(c.serverStartLegacyRepositoryAPI)) - - if c.serverStartHTMLPath != "" { - fileServer := srv.ServeStaticFiles(http.Dir(c.serverStartHTMLPath)) - mux.Handle("/", fileServer) - } else if c.serverStartUI { - mux.Handle("/", srv.ServeStaticFiles(server.AssetFile())) - } + c.setupHandlers(srv, m) httpServer := &http.Server{ Addr: stripProtocol(c.sf.serverAddress), @@ -185,11 +185,11 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error // init prometheus after adding interceptors that require credentials, so that this // handler can be called without auth - if err = initPrometheus(mux); err != nil { + if err = initPrometheus(m); err != nil { return errors.Wrap(err, "error initializing Prometheus") } - var handler http.Handler = mux + var handler http.Handler = m if c.serverStartGRPC { handler = srv.GRPCRouterHandler(handler) @@ -222,7 +222,27 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error return errors.Wrap(srv.SetRepository(ctx, nil), "error setting active repository") } -func initPrometheus(mux *http.ServeMux) error { +func (c *commandServerStart) setupHandlers(srv *server.Server, m *mux.Router) { + if c.serverStartLegacyRepositoryAPI { + srv.SetupRepositoryAPIHandlers(m) + } + + if c.serverStartControlAPI { + srv.SetupControlAPIHandlers(m) + } + + if c.serverStartUI { + srv.SetupHTMLUIAPIHandlers(m) + + if c.serverStartHTMLPath != "" { + srv.ServeStaticFiles(m, http.Dir(c.serverStartHTMLPath)) + } else { + srv.ServeStaticFiles(m, server.AssetFile()) + } + } +} + +func initPrometheus(m *mux.Router) error { reg := prom.NewRegistry() if err := reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})); err != nil { return errors.Wrap(err, "error registering process collector") @@ -239,7 +259,7 @@ func initPrometheus(mux *http.ServeMux) error { return errors.Wrap(err, "unable to initialize prometheus exporter") } - mux.Handle("/metrics", pe) + m.Handle("/metrics", pe) return nil } diff --git a/internal/apiclient/apiclient.go b/internal/apiclient/apiclient.go index 1d0c30a3f..3c9905293 100644 --- a/internal/apiclient/apiclient.go +++ b/internal/apiclient/apiclient.go @@ -8,6 +8,8 @@ "io" "net/http" "net/http/cookiejar" + "regexp" + "strings" "github.com/pkg/errors" @@ -18,34 +20,69 @@ var log = logging.Module("client") +// CSRFTokenHeader is the name of CSRF token header that must be sent for most API calls. +// nolint:gosec +const CSRFTokenHeader = "X-Kopia-Csrf-Token" + // KopiaAPIClient provides helper methods for communicating with Kopia API server. type KopiaAPIClient struct { BaseURL string HTTPClient *http.Client + + CSRFToken string } // Get is a helper that performs HTTP GET on a URL with the specified suffix and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Get(ctx context.Context, urlSuffix string, onNotFound error, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodGet, c.BaseURL+urlSuffix, onNotFound, nil, respPayload) + return c.runRequest(ctx, http.MethodGet, c.actualURL(urlSuffix), onNotFound, nil, respPayload) } // Post is a helper that performs HTTP POST on a URL with the specified body from reqPayload and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Post(ctx context.Context, urlSuffix string, reqPayload, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodPost, c.BaseURL+urlSuffix, nil, reqPayload, respPayload) + return c.runRequest(ctx, http.MethodPost, c.actualURL(urlSuffix), nil, reqPayload, respPayload) } // Put is a helper that performs HTTP PUT on a URL with the specified body from reqPayload and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Put(ctx context.Context, urlSuffix string, reqPayload, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodPut, c.BaseURL+urlSuffix, nil, reqPayload, respPayload) + return c.runRequest(ctx, http.MethodPut, c.actualURL(urlSuffix), nil, reqPayload, respPayload) } // Delete is a helper that performs HTTP DELETE on a URL with the specified body from reqPayload and decodes the response // onto respPayload which must be a pointer to byte slice or JSON-serializable structure. func (c *KopiaAPIClient) Delete(ctx context.Context, urlSuffix string, onNotFound error, reqPayload, respPayload interface{}) error { - return c.runRequest(ctx, http.MethodDelete, c.BaseURL+urlSuffix, onNotFound, reqPayload, respPayload) + return c.runRequest(ctx, http.MethodDelete, c.actualURL(urlSuffix), onNotFound, reqPayload, respPayload) +} + +// FetchCSRFTokenForTesting fetches the CSRF token and session cookie for use when making subsequent calls to the API. +// This simulates the browser behavior of downloading the "/" and is required to call the UI-only methods. +func (c *KopiaAPIClient) FetchCSRFTokenForTesting(ctx context.Context) error { + var b []byte + + if err := c.Get(ctx, "/", nil, &b); err != nil { + return err + } + + re := regexp.MustCompile(``) + + match := re.FindSubmatch(b) + if match == nil { + return errors.Errorf("CSRF token not found") + } + + c.CSRFToken = string(match[1]) + + return nil +} + +func (c *KopiaAPIClient) actualURL(suffix string) string { + if strings.HasPrefix(suffix, "/") { + return c.BaseURL + suffix + } + + return c.BaseURL + "/api/v1/" + suffix } func (c *KopiaAPIClient) runRequest(ctx context.Context, method, url string, notFoundError error, reqPayload, respPayload interface{}) error { @@ -59,6 +96,10 @@ func (c *KopiaAPIClient) runRequest(ctx context.Context, method, url string, not return errors.Wrap(err, "error creating request") } + if c.CSRFToken != "" { + req.Header.Add(CSRFTokenHeader, c.CSRFToken) + } + if contentType != "" { req.Header.Set("Content-Type", contentType) } @@ -165,11 +206,12 @@ func NewKopiaAPIClient(options Options) (*KopiaAPIClient, error) { } return &KopiaAPIClient{ - options.BaseURL + "/api/v1/", + options.BaseURL, &http.Client{ Jar: cj, Transport: transport, }, + "", }, nil } diff --git a/internal/server/api_cli_test.go b/internal/server/api_cli_test.go index 27c00eca8..ddad1f27e 100644 --- a/internal/server/api_cli_test.go +++ b/internal/server/api_cli_test.go @@ -23,6 +23,7 @@ func TestCLIAPI(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) resp := &serverapi.CLIInfo{} require.NoError(t, cli.Get(ctx, "cli", nil, resp)) diff --git a/internal/server/api_paths_test.go b/internal/server/api_paths_test.go index 66df7789a..ff64e99fc 100644 --- a/internal/server/api_paths_test.go +++ b/internal/server/api_paths_test.go @@ -23,6 +23,7 @@ func TestPathsAPI(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir0 := testutil.TempDirectory(t) diff --git a/internal/server/api_policies_test.go b/internal/server/api_policies_test.go index d745a3bcd..285808a0c 100644 --- a/internal/server/api_policies_test.go +++ b/internal/server/api_policies_test.go @@ -29,6 +29,7 @@ func TestPolicies(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir0 := testutil.TempDirectory(t) si0 := localSource(env, dir0) diff --git a/internal/server/api_sources_test.go b/internal/server/api_sources_test.go index 64022ec97..f9c124982 100644 --- a/internal/server/api_sources_test.go +++ b/internal/server/api_sources_test.go @@ -30,6 +30,7 @@ func TestSnapshotCounters(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir := testutil.TempDirectory(t) si := localSource(env, dir) @@ -103,6 +104,7 @@ func TestSourceRefreshesAfterPolicy(t *testing.T) { }) require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) dir := testutil.TempDirectory(t) si := localSource(env, dir) diff --git a/internal/server/api_ui_pref_test.go b/internal/server/api_ui_pref_test.go index 8c8890519..21262a4ea 100644 --- a/internal/server/api_ui_pref_test.go +++ b/internal/server/api_ui_pref_test.go @@ -23,6 +23,8 @@ func TestUIPreferences(t *testing.T) { require.NoError(t, err) + require.NoError(t, cli.FetchCSRFTokenForTesting(ctx)) + var p, p2 serverapi.UIPreferences require.NoError(t, cli.Get(ctx, "ui-preferences", nil, &p)) diff --git a/internal/server/server.go b/internal/server/server.go index c0a096fb4..0ff3be9d4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -45,6 +45,13 @@ kopiaAuthCookieIssuer = "kopia-server" ) +type csrfTokenOption int + +const ( + csrfTokenRequired csrfTokenOption = 1 + iota + csrfTokenNotRequired +) + type apiRequestFunc func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) // Server exposes simple HTTP API for programmatically accessing Kopia features. @@ -72,92 +79,81 @@ type Server struct { grpcServerState } -// APIHandlers handles API requests. -func (s *Server) APIHandlers(legacyAPI bool) http.Handler { - m := mux.NewRouter() - +// SetupHTMLUIAPIHandlers registers API requests required by the HTMLUI. +func (s *Server) SetupHTMLUIAPIHandlers(m *mux.Router) { // sources - m.HandleFunc("/api/v1/sources", s.handleAPI(requireUIUser, s.handleSourcesList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/sources", s.handleAPI(requireUIUser, s.handleSourcesCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/sources/upload", s.handleAPI(requireUIUser, s.handleUpload)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/sources/cancel", s.handleAPI(requireUIUser, s.handleCancel)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/sources", s.handleUI(s.handleSourcesList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/sources", s.handleUI(s.handleSourcesCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/sources/upload", s.handleUI(s.handleUpload)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/sources/cancel", s.handleUI(s.handleCancel)).Methods(http.MethodPost) // snapshots - m.HandleFunc("/api/v1/snapshots", s.handleAPI(requireUIUser, s.handleSnapshotList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyPut)).Methods(http.MethodPut) - m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyDelete)).Methods(http.MethodDelete) - m.HandleFunc("/api/v1/policy/resolve", s.handleAPI(requireUIUser, s.handlePolicyResolve)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/snapshots", s.handleUI(s.handleSnapshotList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyPut)).Methods(http.MethodPut) + m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyDelete)).Methods(http.MethodDelete) + m.HandleFunc("/api/v1/policy/resolve", s.handleUI(s.handlePolicyResolve)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/policies", s.handleUI(s.handlePolicyList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/refresh", s.handleUI(s.handleRefresh)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/objects/{objectID}", s.requireAuth(csrfTokenNotRequired, s.handleObjectGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/restore", s.handleUI(s.handleRestore)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/estimate", s.handleUI(s.handleEstimate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/paths/resolve", s.handleUI(s.handlePathResolve)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/cli", s.handleUI(s.handleCLIInfo)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/status", s.handleUI(s.handleRepoStatus)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/sync", s.handleUI(s.handleRepoSync)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/connect", s.handleUIPossiblyNotConnected(s.handleRepoConnect)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/exists", s.handleUIPossiblyNotConnected(s.handleRepoExists)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/create", s.handleUIPossiblyNotConnected(s.handleRepoCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/description", s.handleUI(s.handleRepoSetDescription)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/disconnect", s.handleUI(s.handleRepoDisconnect)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/algorithms", s.handleUIPossiblyNotConnected(s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/throttle", s.handleUI(s.handleRepoGetThrottle)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/repo/throttle", s.handleUI(s.handleRepoSetThrottle)).Methods(http.MethodPut) - m.HandleFunc("/api/v1/policies", s.handleAPI(requireUIUser, s.handlePolicyList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/mounts", s.handleUI(s.handleMountCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleUI(s.handleMountDelete)).Methods(http.MethodDelete) + m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleUI(s.handleMountGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/mounts", s.handleUI(s.handleMountList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/refresh", s.handleAPI(anyAuthenticatedUser, s.handleRefresh)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/current-user", s.handleUIPossiblyNotConnected(s.handleCurrentUser)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/ui-preferences", s.handleUIPossiblyNotConnected(s.handleGetUIPreferences)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/ui-preferences", s.handleUIPossiblyNotConnected(s.handleSetUIPreferences)).Methods(http.MethodPut) - m.HandleFunc("/api/v1/objects/{objectID}", s.requireAuth(s.handleObjectGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/restore", s.handleAPI(requireUIUser, s.handleRestore)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/estimate", s.handleAPI(requireUIUser, s.handleEstimate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/tasks-summary", s.handleUI(s.handleTaskSummary)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks", s.handleUI(s.handleTaskList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks/{taskID}", s.handleUI(s.handleTaskInfo)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks/{taskID}/logs", s.handleUI(s.handleTaskLogs)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/tasks/{taskID}/cancel", s.handleUI(s.handleTaskCancel)).Methods(http.MethodPost) +} - // path APIs - m.HandleFunc("/api/v1/paths/resolve", s.handleAPI(requireUIUser, s.handlePathResolve)).Methods(http.MethodPost) +// SetupRepositoryAPIHandlers registers HTTP repository API handlers. +func (s *Server) SetupRepositoryAPIHandlers(m *mux.Router) { + m.HandleFunc("/api/v1/flush", s.handleRepositoryAPI(anyAuthenticatedUser, s.handleFlush)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/repo/parameters", s.handleRepositoryAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet) - // path APIs - m.HandleFunc("/api/v1/cli", s.handleAPI(requireUIUser, s.handleCLIInfo)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentInfo)).Methods(http.MethodGet).Queries("info", "1") + m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelAppend), s.handleContentPut)).Methods(http.MethodPut) - // methods that can be called by any authenticated user (UI or remote user). - m.HandleFunc("/api/v1/flush", s.handleAPI(anyAuthenticatedUser, s.handleFlush)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/status", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoStatus)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/repo/sync", s.handleAPI(anyAuthenticatedUser, s.handleRepoSync)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestGet)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestDelete)).Methods(http.MethodDelete) + m.HandleFunc("/api/v1/manifests", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestCreate)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/manifests", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestList)).Methods(http.MethodGet) +} - m.HandleFunc("/api/v1/repo/connect", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoConnect)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/exists", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoExists)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/create", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/description", s.handleAPI(requireUIUser, s.handleRepoSetDescription)).Methods(http.MethodPost) - - m.HandleFunc("/api/v1/repo/disconnect", s.handleAPI(requireUIUser, s.handleRepoDisconnect)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/repo/algorithms", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoGetThrottle)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoSetThrottle)).Methods(http.MethodPut) - - if legacyAPI { - m.HandleFunc("/api/v1/repo/parameters", s.handleAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet) - - m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentInfo)).Methods(http.MethodGet).Queries("info", "1") - m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelAppend), s.handleContentPut)).Methods(http.MethodPut) - - m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestDelete)).Methods(http.MethodDelete) - m.HandleFunc("/api/v1/manifests", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/manifests", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestList)).Methods(http.MethodGet) - } - - m.HandleFunc("/api/v1/mounts", s.handleAPI(requireUIUser, s.handleMountCreate)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleAPI(requireUIUser, s.handleMountDelete)).Methods(http.MethodDelete) - m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleAPI(requireUIUser, s.handleMountGet)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/mounts", s.handleAPI(requireUIUser, s.handleMountList)).Methods(http.MethodGet) - - m.HandleFunc("/api/v1/current-user", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleCurrentUser)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/ui-preferences", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleGetUIPreferences)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/ui-preferences", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleSetUIPreferences)).Methods(http.MethodPut) - - m.HandleFunc("/api/v1/tasks-summary", s.handleAPI(requireUIUser, s.handleTaskSummary)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks", s.handleAPI(requireUIUser, s.handleTaskList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks/{taskID}", s.handleAPI(requireUIUser, s.handleTaskInfo)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks/{taskID}/logs", s.handleAPI(requireUIUser, s.handleTaskLogs)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/tasks/{taskID}/cancel", s.handleAPI(requireUIUser, s.handleTaskCancel)).Methods(http.MethodPost) - - // server control API, requires authentication as `server-control`. - m.HandleFunc("/api/v1/control/status", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleRepoStatus)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/control/sources", s.handleAPI(requireServerControlUser, s.handleSourcesList)).Methods(http.MethodGet) - m.HandleFunc("/api/v1/control/flush", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleFlush)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/refresh", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleRefresh)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/shutdown", s.handleAPI(requireServerControlUser, s.handleShutdown)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/trigger-snapshot", s.handleAPI(requireServerControlUser, s.handleUpload)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/cancel-snapshot", s.handleAPI(requireServerControlUser, s.handleCancel)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/pause-source", s.handleAPI(requireServerControlUser, s.handlePause)).Methods(http.MethodPost) - m.HandleFunc("/api/v1/control/resume-source", s.handleAPI(requireServerControlUser, s.handleResume)).Methods(http.MethodPost) - - return m +// SetupControlAPIHandlers registers control API handlers. +func (s *Server) SetupControlAPIHandlers(m *mux.Router) { + // server control API, requires authentication as `server-control` and no CSRF token. + m.HandleFunc("/api/v1/control/sources", s.handleServerControlAPI(s.handleSourcesList)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/control/status", s.handleServerControlAPIPossiblyNotConnected(s.handleRepoStatus)).Methods(http.MethodGet) + m.HandleFunc("/api/v1/control/flush", s.handleServerControlAPI(s.handleFlush)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/refresh", s.handleServerControlAPI(s.handleRefresh)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/shutdown", s.handleServerControlAPIPossiblyNotConnected(s.handleShutdown)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/trigger-snapshot", s.handleServerControlAPI(s.handleUpload)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/cancel-snapshot", s.handleServerControlAPI(s.handleCancel)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/pause-source", s.handleServerControlAPI(s.handlePause)).Methods(http.MethodPost) + m.HandleFunc("/api/v1/control/resume-source", s.handleServerControlAPI(s.handleResume)).Methods(http.MethodPost) } func (s *Server) isAuthenticated(w http.ResponseWriter, r *http.Request) bool { @@ -198,6 +194,7 @@ func (s *Server) isAuthenticated(w http.ResponseWriter, r *http.Request) bool { Name: kopiaAuthCookie, Value: ac, Expires: now.Add(kopiaAuthCookieTTL), + Path: "/", }) } @@ -233,12 +230,19 @@ func (s *Server) generateShortTermAuthCookie(username string, now time.Time) (st }).SignedString(s.authCookieSigningKey) } -func (s *Server) requireAuth(f http.HandlerFunc) http.HandlerFunc { +func (s *Server) requireAuth(checkCSRFToken csrfTokenOption, f http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !s.isAuthenticated(w, r) { return } + if checkCSRFToken == csrfTokenRequired { + if !s.validateCSRFToken(r) { + http.Error(w, "Invalid or missing CSRF token.\n", http.StatusUnauthorized) + return + } + } + f(w, r) } } @@ -257,8 +261,8 @@ func (s *Server) httpAuthorizationInfo(ctx context.Context, r *http.Request) aut type isAuthorizedFunc func(s *Server, r *http.Request) bool -func (s *Server) handleAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc { - return s.handleAPIPossiblyNotConnected(isAuthorized, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { +func (s *Server) handleServerControlAPI(f apiRequestFunc) http.HandlerFunc { + return s.handleServerControlAPIPossiblyNotConnected(func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { if s.rep == nil { return nil, requestError(serverapi.ErrorNotConnected, "not connected") } @@ -267,8 +271,38 @@ func (s *Server) handleAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http }) } -func (s *Server) handleAPIPossiblyNotConnected(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc { - return s.requireAuth(func(w http.ResponseWriter, r *http.Request) { +func (s *Server) handleServerControlAPIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(requireServerControlUser, csrfTokenNotRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + return f(ctx, r, body) + }) +} + +func (s *Server) handleRepositoryAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(isAuthorized, csrfTokenNotRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + if s.rep == nil { + return nil, requestError(serverapi.ErrorNotConnected, "not connected") + } + + return f(ctx, r, body) + }) +} + +func (s *Server) handleUI(f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(requireUIUser, csrfTokenRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) { + if s.rep == nil { + return nil, requestError(serverapi.ErrorNotConnected, "not connected") + } + + return f(ctx, r, body) + }) +} + +func (s *Server) handleUIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc { + return s.handleRequestPossiblyNotConnected(requireUIUser, csrfTokenRequired, f) +} + +func (s *Server) handleRequestPossiblyNotConnected(isAuthorized isAuthorizedFunc, checkCSRFToken csrfTokenOption, f apiRequestFunc) http.HandlerFunc { + return s.requireAuth(checkCSRFToken, func(w http.ResponseWriter, r *http.Request) { // we must pre-read request body before acquiring the lock as it sometimes leads to deadlock // in HTTP/2 server. // See https://github.com/golang/go/issues/40816 @@ -682,11 +716,18 @@ func (s *Server) isKnownUIRoute(path string) bool { strings.HasPrefix(path, "/repo") } -func (s *Server) patchIndexBytes(b []byte) []byte { +func (s *Server) patchIndexBytes(sessionID string, b []byte) []byte { if s.options.UITitlePrefix != "" { b = bytes.ReplaceAll(b, []byte("