fix(security): prevent cross-site request forgery in the UI website (#1653)

* fix(security): prevent cross-site request forgery in the UI website

This fixes a [cross-site request forgery (CSRF)](https://en.wikipedia.org/wiki/Cross-site_request_forgery)
vulnerability in self-hosted UI for Kopia server.

The vulnerability allows potential attacker to make unauthorized API
calls against a running Kopia server. It requires an attacker to trick
the user into visiting a malicious website while also logged into a
Kopia website.

The vulnerability only affected self-hosted Kopia servers with UI. The
following configurations were not vulnerable:

* Kopia Repository Server without UI
* KopiaUI (desktop app)
* command-line usage of `kopia`

All users are strongly recommended to upgrade at the earliest
convenience.

* pr feedback
This commit is contained in:
Jarek Kowalski
2022-01-13 11:31:51 -08:00
committed by GitHub
parent 2385ab19c9
commit 3d58566644
13 changed files with 478 additions and 135 deletions

View File

@@ -12,6 +12,7 @@
"github.com/alecthomas/kingpin"
"github.com/fatih/color"
"github.com/gorilla/mux"
"github.com/mattn/go-colorable"
"github.com/pkg/errors"
@@ -425,21 +426,21 @@ func (c *App) maybeRepositoryAction(act func(ctx context.Context, rep repo.Repos
defer gather.DumpStats(ctx)
if c.metricsListenAddr != "" {
mux := http.NewServeMux()
if err := initPrometheus(mux); err != nil {
m := mux.NewRouter()
if err := initPrometheus(m); err != nil {
return errors.Wrap(err, "unable to initialize prometheus.")
}
if c.enablePProf {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
m.HandleFunc("/debug/pprof/", pprof.Index)
m.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
m.HandleFunc("/debug/pprof/profile", pprof.Profile)
m.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
m.HandleFunc("/debug/pprof/trace", pprof.Trace)
}
log(ctx).Infof("starting prometheus metrics on %v", c.metricsListenAddr)
go http.ListenAndServe(c.metricsListenAddr, mux) // nolint:errcheck
go http.ListenAndServe(c.metricsListenAddr, m) // nolint:errcheck
}
memtrack.Dump(ctx, "before openRepository")

View File

@@ -14,6 +14,7 @@
"time"
"contrib.go.opencensus.io/exporter/prometheus"
"github.com/gorilla/mux"
"github.com/pkg/errors"
prom "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
@@ -30,10 +31,11 @@ type commandServerStart struct {
co connectOptions
serverStartHTMLPath string
serverStartUI bool
serverStartUI bool
serverStartLegacyRepositoryAPI bool
serverStartGRPC bool
serverStartControlAPI bool
serverStartRefreshInterval time.Duration
serverStartInsecure bool
@@ -63,6 +65,8 @@ type commandServerStart struct {
logServerRequests bool
disableCSRFTokenChecks bool // disable CSRF token checks - used for development/debugging only
sf serverFlags
svc advancedAppServices
out textOutput
@@ -75,6 +79,7 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent
cmd.Flag("legacy-api", "Start the legacy server API").Default("true").BoolVar(&c.serverStartLegacyRepositoryAPI)
cmd.Flag("grpc", "Start the GRPC server").Default("true").BoolVar(&c.serverStartGRPC)
cmd.Flag("control-api", "Start the control API").Default("true").BoolVar(&c.serverStartControlAPI)
cmd.Flag("refresh-interval", "Frequency for refreshing repository status").Default("300s").DurationVar(&c.serverStartRefreshInterval)
cmd.Flag("insecure", "Allow insecure configurations (do not use in production)").Hidden().BoolVar(&c.serverStartInsecure)
@@ -104,6 +109,7 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent
cmd.Flag("ui-preferences-file", "Path to JSON file storing UI preferences").StringVar(&c.uiPreferencesFile)
cmd.Flag("log-server-requests", "Log server requests").Hidden().BoolVar(&c.logServerRequests)
cmd.Flag("disable-csrf-token-checks", "Disable CSRF token").Hidden().BoolVar(&c.disableCSRFTokenChecks)
c.sf.setup(cmd)
c.co.setup(cmd)
@@ -116,7 +122,6 @@ func (c *commandServerStart) setup(svc advancedAppServices, parent commandParent
}))
}
// nolint:funlen
func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error {
authn, err := c.getAuthenticator(ctx)
if err != nil {
@@ -142,6 +147,8 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error
PasswordPersist: c.svc.passwordPersistenceStrategy(),
UIPreferencesFile: uiPreferencesFile,
UITitlePrefix: c.uiTitlePrefix,
DisableCSRFTokenChecks: c.disableCSRFTokenChecks,
})
if err != nil {
return errors.Wrap(err, "unable to initialize server")
@@ -155,16 +162,9 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error
return errors.Wrap(err, "error connecting to repository")
}
mux := http.NewServeMux()
m := mux.NewRouter()
mux.Handle("/api/", srv.APIHandlers(c.serverStartLegacyRepositoryAPI))
if c.serverStartHTMLPath != "" {
fileServer := srv.ServeStaticFiles(http.Dir(c.serverStartHTMLPath))
mux.Handle("/", fileServer)
} else if c.serverStartUI {
mux.Handle("/", srv.ServeStaticFiles(server.AssetFile()))
}
c.setupHandlers(srv, m)
httpServer := &http.Server{
Addr: stripProtocol(c.sf.serverAddress),
@@ -185,11 +185,11 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error
// init prometheus after adding interceptors that require credentials, so that this
// handler can be called without auth
if err = initPrometheus(mux); err != nil {
if err = initPrometheus(m); err != nil {
return errors.Wrap(err, "error initializing Prometheus")
}
var handler http.Handler = mux
var handler http.Handler = m
if c.serverStartGRPC {
handler = srv.GRPCRouterHandler(handler)
@@ -222,7 +222,27 @@ func (c *commandServerStart) run(ctx context.Context, rep repo.Repository) error
return errors.Wrap(srv.SetRepository(ctx, nil), "error setting active repository")
}
func initPrometheus(mux *http.ServeMux) error {
func (c *commandServerStart) setupHandlers(srv *server.Server, m *mux.Router) {
if c.serverStartLegacyRepositoryAPI {
srv.SetupRepositoryAPIHandlers(m)
}
if c.serverStartControlAPI {
srv.SetupControlAPIHandlers(m)
}
if c.serverStartUI {
srv.SetupHTMLUIAPIHandlers(m)
if c.serverStartHTMLPath != "" {
srv.ServeStaticFiles(m, http.Dir(c.serverStartHTMLPath))
} else {
srv.ServeStaticFiles(m, server.AssetFile())
}
}
}
func initPrometheus(m *mux.Router) error {
reg := prom.NewRegistry()
if err := reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})); err != nil {
return errors.Wrap(err, "error registering process collector")
@@ -239,7 +259,7 @@ func initPrometheus(mux *http.ServeMux) error {
return errors.Wrap(err, "unable to initialize prometheus exporter")
}
mux.Handle("/metrics", pe)
m.Handle("/metrics", pe)
return nil
}

View File

@@ -8,6 +8,8 @@
"io"
"net/http"
"net/http/cookiejar"
"regexp"
"strings"
"github.com/pkg/errors"
@@ -18,34 +20,69 @@
var log = logging.Module("client")
// CSRFTokenHeader is the name of CSRF token header that must be sent for most API calls.
// nolint:gosec
const CSRFTokenHeader = "X-Kopia-Csrf-Token"
// KopiaAPIClient provides helper methods for communicating with Kopia API server.
type KopiaAPIClient struct {
BaseURL string
HTTPClient *http.Client
CSRFToken string
}
// Get is a helper that performs HTTP GET on a URL with the specified suffix and decodes the response
// onto respPayload which must be a pointer to byte slice or JSON-serializable structure.
func (c *KopiaAPIClient) Get(ctx context.Context, urlSuffix string, onNotFound error, respPayload interface{}) error {
return c.runRequest(ctx, http.MethodGet, c.BaseURL+urlSuffix, onNotFound, nil, respPayload)
return c.runRequest(ctx, http.MethodGet, c.actualURL(urlSuffix), onNotFound, nil, respPayload)
}
// Post is a helper that performs HTTP POST on a URL with the specified body from reqPayload and decodes the response
// onto respPayload which must be a pointer to byte slice or JSON-serializable structure.
func (c *KopiaAPIClient) Post(ctx context.Context, urlSuffix string, reqPayload, respPayload interface{}) error {
return c.runRequest(ctx, http.MethodPost, c.BaseURL+urlSuffix, nil, reqPayload, respPayload)
return c.runRequest(ctx, http.MethodPost, c.actualURL(urlSuffix), nil, reqPayload, respPayload)
}
// Put is a helper that performs HTTP PUT on a URL with the specified body from reqPayload and decodes the response
// onto respPayload which must be a pointer to byte slice or JSON-serializable structure.
func (c *KopiaAPIClient) Put(ctx context.Context, urlSuffix string, reqPayload, respPayload interface{}) error {
return c.runRequest(ctx, http.MethodPut, c.BaseURL+urlSuffix, nil, reqPayload, respPayload)
return c.runRequest(ctx, http.MethodPut, c.actualURL(urlSuffix), nil, reqPayload, respPayload)
}
// Delete is a helper that performs HTTP DELETE on a URL with the specified body from reqPayload and decodes the response
// onto respPayload which must be a pointer to byte slice or JSON-serializable structure.
func (c *KopiaAPIClient) Delete(ctx context.Context, urlSuffix string, onNotFound error, reqPayload, respPayload interface{}) error {
return c.runRequest(ctx, http.MethodDelete, c.BaseURL+urlSuffix, onNotFound, reqPayload, respPayload)
return c.runRequest(ctx, http.MethodDelete, c.actualURL(urlSuffix), onNotFound, reqPayload, respPayload)
}
// FetchCSRFTokenForTesting fetches the CSRF token and session cookie for use when making subsequent calls to the API.
// This simulates the browser behavior of downloading the "/" and is required to call the UI-only methods.
func (c *KopiaAPIClient) FetchCSRFTokenForTesting(ctx context.Context) error {
var b []byte
if err := c.Get(ctx, "/", nil, &b); err != nil {
return err
}
re := regexp.MustCompile(`<meta name="kopia-csrf-token" content="(.*)" />`)
match := re.FindSubmatch(b)
if match == nil {
return errors.Errorf("CSRF token not found")
}
c.CSRFToken = string(match[1])
return nil
}
func (c *KopiaAPIClient) actualURL(suffix string) string {
if strings.HasPrefix(suffix, "/") {
return c.BaseURL + suffix
}
return c.BaseURL + "/api/v1/" + suffix
}
func (c *KopiaAPIClient) runRequest(ctx context.Context, method, url string, notFoundError error, reqPayload, respPayload interface{}) error {
@@ -59,6 +96,10 @@ func (c *KopiaAPIClient) runRequest(ctx context.Context, method, url string, not
return errors.Wrap(err, "error creating request")
}
if c.CSRFToken != "" {
req.Header.Add(CSRFTokenHeader, c.CSRFToken)
}
if contentType != "" {
req.Header.Set("Content-Type", contentType)
}
@@ -165,11 +206,12 @@ func NewKopiaAPIClient(options Options) (*KopiaAPIClient, error) {
}
return &KopiaAPIClient{
options.BaseURL + "/api/v1/",
options.BaseURL,
&http.Client{
Jar: cj,
Transport: transport,
},
"",
}, nil
}

View File

@@ -23,6 +23,7 @@ func TestCLIAPI(t *testing.T) {
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
resp := &serverapi.CLIInfo{}
require.NoError(t, cli.Get(ctx, "cli", nil, resp))

View File

@@ -23,6 +23,7 @@ func TestPathsAPI(t *testing.T) {
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
dir0 := testutil.TempDirectory(t)

View File

@@ -29,6 +29,7 @@ func TestPolicies(t *testing.T) {
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
dir0 := testutil.TempDirectory(t)
si0 := localSource(env, dir0)

View File

@@ -30,6 +30,7 @@ func TestSnapshotCounters(t *testing.T) {
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
dir := testutil.TempDirectory(t)
si := localSource(env, dir)
@@ -103,6 +104,7 @@ func TestSourceRefreshesAfterPolicy(t *testing.T) {
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
dir := testutil.TempDirectory(t)
si := localSource(env, dir)

View File

@@ -23,6 +23,8 @@ func TestUIPreferences(t *testing.T) {
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
var p, p2 serverapi.UIPreferences
require.NoError(t, cli.Get(ctx, "ui-preferences", nil, &p))

View File

@@ -45,6 +45,13 @@
kopiaAuthCookieIssuer = "kopia-server"
)
type csrfTokenOption int
const (
csrfTokenRequired csrfTokenOption = 1 + iota
csrfTokenNotRequired
)
type apiRequestFunc func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError)
// Server exposes simple HTTP API for programmatically accessing Kopia features.
@@ -72,92 +79,81 @@ type Server struct {
grpcServerState
}
// APIHandlers handles API requests.
func (s *Server) APIHandlers(legacyAPI bool) http.Handler {
m := mux.NewRouter()
// SetupHTMLUIAPIHandlers registers API requests required by the HTMLUI.
func (s *Server) SetupHTMLUIAPIHandlers(m *mux.Router) {
// sources
m.HandleFunc("/api/v1/sources", s.handleAPI(requireUIUser, s.handleSourcesList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/sources", s.handleAPI(requireUIUser, s.handleSourcesCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/sources/upload", s.handleAPI(requireUIUser, s.handleUpload)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/sources/cancel", s.handleAPI(requireUIUser, s.handleCancel)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/sources", s.handleUI(s.handleSourcesList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/sources", s.handleUI(s.handleSourcesCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/sources/upload", s.handleUI(s.handleUpload)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/sources/cancel", s.handleUI(s.handleCancel)).Methods(http.MethodPost)
// snapshots
m.HandleFunc("/api/v1/snapshots", s.handleAPI(requireUIUser, s.handleSnapshotList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyPut)).Methods(http.MethodPut)
m.HandleFunc("/api/v1/policy", s.handleAPI(requireUIUser, s.handlePolicyDelete)).Methods(http.MethodDelete)
m.HandleFunc("/api/v1/policy/resolve", s.handleAPI(requireUIUser, s.handlePolicyResolve)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/snapshots", s.handleUI(s.handleSnapshotList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyPut)).Methods(http.MethodPut)
m.HandleFunc("/api/v1/policy", s.handleUI(s.handlePolicyDelete)).Methods(http.MethodDelete)
m.HandleFunc("/api/v1/policy/resolve", s.handleUI(s.handlePolicyResolve)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/policies", s.handleUI(s.handlePolicyList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/refresh", s.handleUI(s.handleRefresh)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/objects/{objectID}", s.requireAuth(csrfTokenNotRequired, s.handleObjectGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/restore", s.handleUI(s.handleRestore)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/estimate", s.handleUI(s.handleEstimate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/paths/resolve", s.handleUI(s.handlePathResolve)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/cli", s.handleUI(s.handleCLIInfo)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/status", s.handleUI(s.handleRepoStatus)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/sync", s.handleUI(s.handleRepoSync)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/connect", s.handleUIPossiblyNotConnected(s.handleRepoConnect)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/exists", s.handleUIPossiblyNotConnected(s.handleRepoExists)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/create", s.handleUIPossiblyNotConnected(s.handleRepoCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/description", s.handleUI(s.handleRepoSetDescription)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/disconnect", s.handleUI(s.handleRepoDisconnect)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/algorithms", s.handleUIPossiblyNotConnected(s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/throttle", s.handleUI(s.handleRepoGetThrottle)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/throttle", s.handleUI(s.handleRepoSetThrottle)).Methods(http.MethodPut)
m.HandleFunc("/api/v1/policies", s.handleAPI(requireUIUser, s.handlePolicyList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/mounts", s.handleUI(s.handleMountCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleUI(s.handleMountDelete)).Methods(http.MethodDelete)
m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleUI(s.handleMountGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/mounts", s.handleUI(s.handleMountList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/refresh", s.handleAPI(anyAuthenticatedUser, s.handleRefresh)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/current-user", s.handleUIPossiblyNotConnected(s.handleCurrentUser)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/ui-preferences", s.handleUIPossiblyNotConnected(s.handleGetUIPreferences)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/ui-preferences", s.handleUIPossiblyNotConnected(s.handleSetUIPreferences)).Methods(http.MethodPut)
m.HandleFunc("/api/v1/objects/{objectID}", s.requireAuth(s.handleObjectGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/restore", s.handleAPI(requireUIUser, s.handleRestore)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/estimate", s.handleAPI(requireUIUser, s.handleEstimate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/tasks-summary", s.handleUI(s.handleTaskSummary)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks", s.handleUI(s.handleTaskList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks/{taskID}", s.handleUI(s.handleTaskInfo)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks/{taskID}/logs", s.handleUI(s.handleTaskLogs)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks/{taskID}/cancel", s.handleUI(s.handleTaskCancel)).Methods(http.MethodPost)
}
// path APIs
m.HandleFunc("/api/v1/paths/resolve", s.handleAPI(requireUIUser, s.handlePathResolve)).Methods(http.MethodPost)
// SetupRepositoryAPIHandlers registers HTTP repository API handlers.
func (s *Server) SetupRepositoryAPIHandlers(m *mux.Router) {
m.HandleFunc("/api/v1/flush", s.handleRepositoryAPI(anyAuthenticatedUser, s.handleFlush)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/parameters", s.handleRepositoryAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet)
// path APIs
m.HandleFunc("/api/v1/cli", s.handleAPI(requireUIUser, s.handleCLIInfo)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentInfo)).Methods(http.MethodGet).Queries("info", "1")
m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/contents/{contentID}", s.handleRepositoryAPI(requireContentAccess(auth.AccessLevelAppend), s.handleContentPut)).Methods(http.MethodPut)
// methods that can be called by any authenticated user (UI or remote user).
m.HandleFunc("/api/v1/flush", s.handleAPI(anyAuthenticatedUser, s.handleFlush)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/status", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoStatus)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/sync", s.handleAPI(anyAuthenticatedUser, s.handleRepoSync)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestDelete)).Methods(http.MethodDelete)
m.HandleFunc("/api/v1/manifests", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/manifests", s.handleRepositoryAPI(handlerWillCheckAuthorization, s.handleManifestList)).Methods(http.MethodGet)
}
m.HandleFunc("/api/v1/repo/connect", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoConnect)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/exists", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoExists)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/create", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/description", s.handleAPI(requireUIUser, s.handleRepoSetDescription)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/disconnect", s.handleAPI(requireUIUser, s.handleRepoDisconnect)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/repo/algorithms", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleRepoSupportedAlgorithms)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoGetThrottle)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/repo/throttle", s.handleAPI(requireUIUser, s.handleRepoSetThrottle)).Methods(http.MethodPut)
if legacyAPI {
m.HandleFunc("/api/v1/repo/parameters", s.handleAPI(anyAuthenticatedUser, s.handleRepoParameters)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentInfo)).Methods(http.MethodGet).Queries("info", "1")
m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelRead), s.handleContentGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/contents/{contentID}", s.handleAPI(requireContentAccess(auth.AccessLevelAppend), s.handleContentPut)).Methods(http.MethodPut)
m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/manifests/{manifestID}", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestDelete)).Methods(http.MethodDelete)
m.HandleFunc("/api/v1/manifests", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/manifests", s.handleAPI(handlerWillCheckAuthorization, s.handleManifestList)).Methods(http.MethodGet)
}
m.HandleFunc("/api/v1/mounts", s.handleAPI(requireUIUser, s.handleMountCreate)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleAPI(requireUIUser, s.handleMountDelete)).Methods(http.MethodDelete)
m.HandleFunc("/api/v1/mounts/{rootObjectID}", s.handleAPI(requireUIUser, s.handleMountGet)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/mounts", s.handleAPI(requireUIUser, s.handleMountList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/current-user", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleCurrentUser)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/ui-preferences", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleGetUIPreferences)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/ui-preferences", s.handleAPIPossiblyNotConnected(requireUIUser, s.handleSetUIPreferences)).Methods(http.MethodPut)
m.HandleFunc("/api/v1/tasks-summary", s.handleAPI(requireUIUser, s.handleTaskSummary)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks", s.handleAPI(requireUIUser, s.handleTaskList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks/{taskID}", s.handleAPI(requireUIUser, s.handleTaskInfo)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks/{taskID}/logs", s.handleAPI(requireUIUser, s.handleTaskLogs)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/tasks/{taskID}/cancel", s.handleAPI(requireUIUser, s.handleTaskCancel)).Methods(http.MethodPost)
// server control API, requires authentication as `server-control`.
m.HandleFunc("/api/v1/control/status", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleRepoStatus)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/control/sources", s.handleAPI(requireServerControlUser, s.handleSourcesList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/control/flush", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleFlush)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/refresh", s.handleAPIPossiblyNotConnected(requireServerControlUser, s.handleRefresh)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/shutdown", s.handleAPI(requireServerControlUser, s.handleShutdown)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/trigger-snapshot", s.handleAPI(requireServerControlUser, s.handleUpload)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/cancel-snapshot", s.handleAPI(requireServerControlUser, s.handleCancel)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/pause-source", s.handleAPI(requireServerControlUser, s.handlePause)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/resume-source", s.handleAPI(requireServerControlUser, s.handleResume)).Methods(http.MethodPost)
return m
// SetupControlAPIHandlers registers control API handlers.
func (s *Server) SetupControlAPIHandlers(m *mux.Router) {
// server control API, requires authentication as `server-control` and no CSRF token.
m.HandleFunc("/api/v1/control/sources", s.handleServerControlAPI(s.handleSourcesList)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/control/status", s.handleServerControlAPIPossiblyNotConnected(s.handleRepoStatus)).Methods(http.MethodGet)
m.HandleFunc("/api/v1/control/flush", s.handleServerControlAPI(s.handleFlush)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/refresh", s.handleServerControlAPI(s.handleRefresh)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/shutdown", s.handleServerControlAPIPossiblyNotConnected(s.handleShutdown)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/trigger-snapshot", s.handleServerControlAPI(s.handleUpload)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/cancel-snapshot", s.handleServerControlAPI(s.handleCancel)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/pause-source", s.handleServerControlAPI(s.handlePause)).Methods(http.MethodPost)
m.HandleFunc("/api/v1/control/resume-source", s.handleServerControlAPI(s.handleResume)).Methods(http.MethodPost)
}
func (s *Server) isAuthenticated(w http.ResponseWriter, r *http.Request) bool {
@@ -198,6 +194,7 @@ func (s *Server) isAuthenticated(w http.ResponseWriter, r *http.Request) bool {
Name: kopiaAuthCookie,
Value: ac,
Expires: now.Add(kopiaAuthCookieTTL),
Path: "/",
})
}
@@ -233,12 +230,19 @@ func (s *Server) generateShortTermAuthCookie(username string, now time.Time) (st
}).SignedString(s.authCookieSigningKey)
}
func (s *Server) requireAuth(f http.HandlerFunc) http.HandlerFunc {
func (s *Server) requireAuth(checkCSRFToken csrfTokenOption, f http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !s.isAuthenticated(w, r) {
return
}
if checkCSRFToken == csrfTokenRequired {
if !s.validateCSRFToken(r) {
http.Error(w, "Invalid or missing CSRF token.\n", http.StatusUnauthorized)
return
}
}
f(w, r)
}
}
@@ -257,8 +261,8 @@ func (s *Server) httpAuthorizationInfo(ctx context.Context, r *http.Request) aut
type isAuthorizedFunc func(s *Server, r *http.Request) bool
func (s *Server) handleAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc {
return s.handleAPIPossiblyNotConnected(isAuthorized, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
func (s *Server) handleServerControlAPI(f apiRequestFunc) http.HandlerFunc {
return s.handleServerControlAPIPossiblyNotConnected(func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep == nil {
return nil, requestError(serverapi.ErrorNotConnected, "not connected")
}
@@ -267,8 +271,38 @@ func (s *Server) handleAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http
})
}
func (s *Server) handleAPIPossiblyNotConnected(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc {
return s.requireAuth(func(w http.ResponseWriter, r *http.Request) {
func (s *Server) handleServerControlAPIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc {
return s.handleRequestPossiblyNotConnected(requireServerControlUser, csrfTokenNotRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
return f(ctx, r, body)
})
}
func (s *Server) handleRepositoryAPI(isAuthorized isAuthorizedFunc, f apiRequestFunc) http.HandlerFunc {
return s.handleRequestPossiblyNotConnected(isAuthorized, csrfTokenNotRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep == nil {
return nil, requestError(serverapi.ErrorNotConnected, "not connected")
}
return f(ctx, r, body)
})
}
func (s *Server) handleUI(f apiRequestFunc) http.HandlerFunc {
return s.handleRequestPossiblyNotConnected(requireUIUser, csrfTokenRequired, func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep == nil {
return nil, requestError(serverapi.ErrorNotConnected, "not connected")
}
return f(ctx, r, body)
})
}
func (s *Server) handleUIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc {
return s.handleRequestPossiblyNotConnected(requireUIUser, csrfTokenRequired, f)
}
func (s *Server) handleRequestPossiblyNotConnected(isAuthorized isAuthorizedFunc, checkCSRFToken csrfTokenOption, f apiRequestFunc) http.HandlerFunc {
return s.requireAuth(checkCSRFToken, func(w http.ResponseWriter, r *http.Request) {
// we must pre-read request body before acquiring the lock as it sometimes leads to deadlock
// in HTTP/2 server.
// See https://github.com/golang/go/issues/40816
@@ -682,11 +716,18 @@ func (s *Server) isKnownUIRoute(path string) bool {
strings.HasPrefix(path, "/repo")
}
func (s *Server) patchIndexBytes(b []byte) []byte {
func (s *Server) patchIndexBytes(sessionID string, b []byte) []byte {
if s.options.UITitlePrefix != "" {
b = bytes.ReplaceAll(b, []byte("<title>"), []byte("<title>"+html.EscapeString(s.options.UITitlePrefix)))
}
csrfToken := s.generateCSRFToken(sessionID)
// insert <meta name="kopia-csrf-token" content="..." /> just before closing head tag.
b = bytes.ReplaceAll(b,
[]byte(`</head>`),
[]byte(`<meta name="kopia-csrf-token" content="`+csrfToken+`" /></head>`))
return b
}
@@ -706,14 +747,14 @@ func maybeReadIndexBytes(fs http.FileSystem) []byte {
return rd
}
// ServeStaticFiles returns HTTP handler that serves static files and dynamically patches index.html to embed CSRF token, etc.
func (s *Server) ServeStaticFiles(fs http.FileSystem) http.Handler {
// ServeStaticFiles configures HTTP handler that serves static files and dynamically patches index.html to embed CSRF token, etc.
func (s *Server) ServeStaticFiles(m *mux.Router, fs http.FileSystem) {
h := http.FileServer(fs)
// read bytes from 'index.html'.
indexBytes := maybeReadIndexBytes(fs)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if s.isKnownUIRoute(r.URL.Path) {
r2 := new(http.Request)
*r2 = *r
@@ -733,7 +774,22 @@ func (s *Server) ServeStaticFiles(fs http.FileSystem) http.Handler {
}
if r.URL.Path == "/" && indexBytes != nil {
http.ServeContent(w, r, "/", clock.Now(), bytes.NewReader(s.patchIndexBytes(indexBytes)))
var sessionID string
if cookie, err := r.Cookie(kopiaSessionCookie); err == nil {
// already in a session, likely a new tab was opened
sessionID = cookie.Value
} else {
sessionID = uuid.NewString()
http.SetCookie(w, &http.Cookie{
Name: kopiaSessionCookie,
Value: sessionID,
Path: "/",
})
}
http.ServeContent(w, r, "/", clock.Now(), bytes.NewReader(s.patchIndexBytes(sessionID, indexBytes)))
return
}
@@ -752,7 +808,7 @@ type Options struct {
PasswordPersist passwordpersist.Strategy
AuthCookieSigningKey string
LogRequests bool
UIUser string // name of the user allowed to access the UI
UIUser string // name of the user allowed to access the UI API
UIPreferencesFile string // name of the JSON file storing UI preferences
ServerControlUser string // name of the user allowed to access the server control API
DisableCSRFTokenChecks bool

View File

@@ -2,16 +2,72 @@
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"io"
"net/http"
"github.com/kopia/kopia/internal/apiclient"
"github.com/kopia/kopia/internal/auth"
)
// kopiaSessionCookie is the name of the session cookie that Kopia server will generate for all
// UI sessions.
const kopiaSessionCookie = "Kopia-Session-Cookie"
func (s *Server) generateCSRFToken(sessionID string) string {
h := hmac.New(sha256.New, s.authCookieSigningKey)
if _, err := io.WriteString(h, sessionID); err != nil {
panic("io.WriteString() failed: " + err.Error())
}
return hex.EncodeToString(h.Sum(nil))
}
func (s *Server) validateCSRFToken(r *http.Request) bool {
if s.options.DisableCSRFTokenChecks {
return true
}
ctx := r.Context()
path := r.URL.Path
sessionCookie, err := r.Cookie(kopiaSessionCookie)
if err != nil {
log(ctx).Warnf("missing or invalid session cookie for %q: %v", path, err)
return false
}
validToken := s.generateCSRFToken(sessionCookie.Value)
token := r.Header.Get(apiclient.CSRFTokenHeader)
if token == "" {
log(ctx).Warnf("missing CSRF token for %v", path)
return false
}
if subtle.ConstantTimeCompare([]byte(validToken), []byte(token)) == 1 {
return true
}
log(ctx).Warnf("got invalid CSRF token for %v: %v, want %v, session %v", path, token, validToken, sessionCookie.Value)
return false
}
func requireUIUser(s *Server, r *http.Request) bool {
if s.authenticator == nil {
return true
}
if s.options.UIUser == "" {
return false
}
user, _, _ := r.BasicAuth()
return user == s.options.UIUser
@@ -22,6 +78,10 @@ func requireServerControlUser(s *Server, r *http.Request) bool {
return true
}
if s.options.ServerControlUser == "" {
return false
}
user, _, _ := r.BasicAuth()
return user == s.options.ServerControlUser

View File

@@ -0,0 +1,101 @@
package server
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/kopia/kopia/internal/apiclient"
)
func TestGenerateCSRFToken(t *testing.T) {
s1 := &Server{
authCookieSigningKey: []byte("some-key"),
}
s2 := &Server{
authCookieSigningKey: []byte("some-other-key"),
}
cases := []struct {
srv *Server
session string
wantToken string
}{
// CSRF token is a stable function of session ID and per-server so we can hardcode it
{s1, "session1", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0"},
{s2, "session1", "7fd10608493e844581247d44e61de56b80df83ecdae49891f150823f73524ef7"},
{s1, "session2", "e3aeba64243485ac4664e27445f48711b987c3bf5c7e58d1b89eb1e2722fedcd"},
{s2, "session2", "714a124df0f6b6e79500fb06a900e7870a94f50b6d1e532c92a0abc0c63146f8"},
}
for _, tc := range cases {
require.Equal(t, tc.wantToken, tc.srv.generateCSRFToken(tc.session))
}
}
func TestValidateCSRFToken(t *testing.T) {
s1 := &Server{
authCookieSigningKey: []byte("some-key"),
}
s2 := &Server{
authCookieSigningKey: []byte("some-other-key"),
}
s3 := &Server{
options: Options{
DisableCSRFTokenChecks: true,
},
}
cases := []struct {
srv *Server
session string
token string
valid bool
}{
// valid
{s1, "session1", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0", true},
{s2, "session1", "7fd10608493e844581247d44e61de56b80df83ecdae49891f150823f73524ef7", true},
{s1, "session2", "e3aeba64243485ac4664e27445f48711b987c3bf5c7e58d1b89eb1e2722fedcd", true},
{s2, "session2", "714a124df0f6b6e79500fb06a900e7870a94f50b6d1e532c92a0abc0c63146f8", true},
// invalid cases
{s1, "", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0", false}, // missing session ID
{s1, "session2", "", false}, // missing token
{s2, "session2", "invalid", false}, // invalid token
// token is invalid but ignored, since 's3' does not validate tokens.
{s3, "", "557c279a9203afbd5e1edd8a3b091fbcaf699841cd95058954e11886f0a3e6d0", true},
{s3, "session2", "", true},
{s3, "session2", "invalid-token", true},
}
ctx := context.Background()
for i, tc := range cases {
tc := tc
t.Run(fmt.Sprintf("case-%v", i), func(t *testing.T) {
req, err := http.NewRequestWithContext(ctx, "GET", "/somepath", http.NoBody)
require.NoError(t, err)
if tc.session != "" {
req.AddCookie(&http.Cookie{
Name: "Kopia-Session-Cookie",
Value: tc.session,
})
}
if tc.token != "" {
req.Header.Add(apiclient.CSRFTokenHeader, tc.token)
}
require.Equal(t, tc.valid, tc.srv.validateCSRFToken(req))
})
}
}

View File

@@ -12,6 +12,7 @@
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
@@ -71,7 +72,13 @@ func startServer(t *testing.T, env *repotesting.Environment, tls bool) *repo.API
asi := &repo.APIServerInfo{}
hs := httptest.NewUnstartedServer(s.GRPCRouterHandler(s.APIHandlers(true)))
m := mux.NewRouter()
s.SetupHTMLUIAPIHandlers(m)
s.SetupRepositoryAPIHandlers(m)
s.SetupControlAPIHandlers(m)
s.ServeStaticFiles(m, server.AssetFile())
hs := httptest.NewUnstartedServer(s.GRPCRouterHandler(m))
if tls {
hs.EnableHTTP2 = true
hs.StartTLS()
@@ -131,6 +138,7 @@ func TestGPRServer_AuthenticationError(t *testing.T) {
}
}
// nolint:gocyclo
func TestServerUIAccessDeniedToRemoteUser(t *testing.T) {
ctx, env := repotesting.NewEnvironment(t, repotesting.FormatNotImportant)
si := startServer(t, env, true)
@@ -141,6 +149,14 @@ func TestServerUIAccessDeniedToRemoteUser(t *testing.T) {
Username: testUsername + "@" + testHostname,
Password: testPassword,
})
require.NoError(t, err)
uiUserWithoutCSRFToken, err := apiclient.NewKopiaAPIClient(apiclient.Options{
BaseURL: si.BaseURL,
TrustedServerCertificateFingerprint: si.TrustedServerCertificateFingerprint,
Username: testUIUsername,
Password: testUIPassword,
})
require.NoError(t, err)
@@ -153,7 +169,10 @@ func TestServerUIAccessDeniedToRemoteUser(t *testing.T) {
require.NoError(t, err)
// examples of URLs and expected statuses returned when UI user calls them, but which must return 403 when
require.NoError(t, uiUserClient.FetchCSRFTokenForTesting(ctx))
// do not call uiUserWithoutCSRFToken.FetchCSRFTokenForTesting()
// examples of URLs and expected statuses returned when UI user calls them, but which must return 401 due to missing CSRF token
// remote user calls them.
getUrls := map[string]int{
"mounts": http.StatusOK,
@@ -171,8 +190,34 @@ func TestServerUIAccessDeniedToRemoteUser(t *testing.T) {
t.Run(urlSuffix, func(t *testing.T) {
var hsr apiclient.HTTPStatusError
if err := remoteUserClient.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || hsr.HTTPStatusCode != http.StatusForbidden {
t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, http.StatusForbidden)
wantFailure := http.StatusUnauthorized // 401
if urlSuffix == "objects/abcd" {
// this is a special one that does not require CSRF token but will still fail with 403
wantFailure = http.StatusForbidden
}
if err := remoteUserClient.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || (hsr.HTTPStatusCode != wantFailure) {
t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, wantFailure)
}
if wantStatus == http.StatusOK {
if err := uiUserClient.Get(ctx, urlSuffix, nil, nil); err != nil {
t.Fatalf("expected success, got %v", err)
}
} else if err := uiUserClient.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || hsr.HTTPStatusCode != wantStatus {
t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, wantStatus)
}
// objects/abcd does not require CSRF token so will fail with 404 instead of 403.
// This is fine since this is a side-effect-free GET method so same-origin policy
// will protect access to data.
if urlSuffix == "objects/abcd" {
wantFailure = http.StatusNotFound
}
if err := uiUserWithoutCSRFToken.Get(ctx, urlSuffix, nil, nil); !errors.As(err, &hsr) || (hsr.HTTPStatusCode != wantFailure) {
t.Fatalf("error returned expected to be HTTPStatusError %v, want %v", hsr.HTTPStatusCode, wantFailure)
}
if wantStatus == http.StatusOK {
@@ -213,7 +258,6 @@ func remoteRepositoryTest(ctx context.Context, t *testing.T, rep repo.Repository
}, func(ctx context.Context, w repo.RepositoryWriter) error {
mustGetObjectNotFound(ctx, t, w, "abcd")
mustGetManifestNotFound(ctx, t, w, "mnosuchmanifest")
mustManifestNotFound(t, w.DeleteManifest(ctx, manifestID2))
mustListSnapshotCount(ctx, t, w, 0)
result = mustWriteObject(ctx, t, w, written)

View File

@@ -98,6 +98,7 @@ func TestServerStart(t *testing.T) {
LogRequests: true,
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
controlClient, err := apiclient.NewKopiaAPIClient(apiclient.Options{
BaseURL: sp.baseURL,
@@ -113,8 +114,7 @@ func TestServerStart(t *testing.T) {
waitUntilServerStarted(ctx, t, controlClient)
verifyUIServedWithCorrectTitle(t, cli, sp)
st := verifyServerConnected(t, controlClient, true)
require.Equal(t, "filesystem", st.Storage)
verifyServerConnected(t, controlClient, true)
limits, err := serverapi.GetThrottlingLimits(ctx, cli)
require.NoError(t, err)
@@ -234,6 +234,7 @@ func TestServerCreateAndConnectViaAPI(t *testing.T) {
TrustedServerCertificateFingerprint: sp.sha256Fingerprint,
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
controlClient, err := apiclient.NewKopiaAPIClient(apiclient.Options{
BaseURL: sp.baseURL,
@@ -306,20 +307,11 @@ func TestConnectToExistingRepositoryViaAPI(t *testing.T) {
"--override-hostname=fake-hostname", "--override-username=fake-username")
t.Logf("detected server parameters %#v", sp)
cli, err := apiclient.NewKopiaAPIClient(apiclient.Options{
BaseURL: sp.baseURL,
Username: "kopia",
Password: sp.password,
TrustedServerCertificateFingerprint: sp.sha256Fingerprint,
})
require.NoError(t, err)
controlClient, err := apiclient.NewKopiaAPIClient(apiclient.Options{
BaseURL: sp.baseURL,
Username: "server-control",
Password: sp.serverControlPassword,
TrustedServerCertificateFingerprint: sp.sha256Fingerprint,
LogRequests: true,
})
require.NoError(t, err)
@@ -328,6 +320,15 @@ func TestConnectToExistingRepositoryViaAPI(t *testing.T) {
waitUntilServerStarted(ctx, t, controlClient)
verifyServerConnected(t, controlClient, false)
cli, err := apiclient.NewKopiaAPIClient(apiclient.Options{
BaseURL: sp.baseURL,
Username: "kopia",
Password: sp.password,
TrustedServerCertificateFingerprint: sp.sha256Fingerprint,
})
require.NoError(t, err)
require.NoError(t, cli.FetchCSRFTokenForTesting(ctx))
if err = serverapi.ConnectToRepository(ctx, cli, &serverapi.ConnectRepositoryRequest{
Password: testenv.TestRepoPassword,
Storage: connInfo,
@@ -533,10 +534,21 @@ func verifyUIServedWithCorrectTitle(t *testing.T, cli *apiclient.KopiaAPIClient,
func waitUntilServerStarted(ctx context.Context, t *testing.T, cli *apiclient.KopiaAPIClient) {
t.Helper()
if err := retry.PeriodicallyNoValue(ctx, 1*time.Second, 180, "wait for server start", func() error {
require.NoError(t, retry.PeriodicallyNoValue(ctx, 1*time.Second, 180, "wait for server start", func() error {
_, err := serverapi.Status(testlogging.Context(t), cli)
return err
}, retry.Always); err != nil {
t.Fatalf("server failed to start")
}
}, func(err error) bool {
var hs apiclient.HTTPStatusError
if errors.As(err, &hs) {
switch hs.HTTPStatusCode {
case http.StatusBadRequest:
return false
case http.StatusForbidden:
return false
}
}
return true
}))
}