diff --git a/lib/oauthutil/oauthutil.go b/lib/oauthutil/oauthutil.go index 347faea95..0e152ab69 100644 --- a/lib/oauthutil/oauthutil.go +++ b/lib/oauthutil/oauthutil.go @@ -21,6 +21,7 @@ import ( "github.com/rclone/rclone/fs/config/configmap" "github.com/rclone/rclone/fs/fserrors" "github.com/rclone/rclone/fs/fshttp" + "github.com/rclone/rclone/fs/rc" "github.com/rclone/rclone/lib/random" "github.com/skratchdot/open-golang/open" "golang.org/x/oauth2" @@ -30,6 +31,11 @@ import ( var ( // templateString is the template used in the authorization webserver templateString string + + // oauthCancelFn stores the cancel function for the currently active OAuth flow + oauthCancelFn context.CancelFunc + // oauthCancelMu protects oauthCancelFn + oauthCancelMu sync.Mutex ) const ( @@ -770,6 +776,60 @@ func init() { fs.ConfigOAuth = ConfigOAuth } +func init() { + rc.Add(rc.Call{ + Path: "config/oauthstop", + Fn: rcOAuthStop, + Title: "Stop any running OAuth authentication server.", + Help: `Stops the OAuth authentication server if one is running. + +This can be used to recover from an interrupted OAuth flow without +restarting rclone. If no OAuth authentication is in progress, an error +is returned. +`, + }) +} + +func rcOAuthStop(ctx context.Context, in rc.Params) (out rc.Params, err error) { + oauthCancelMu.Lock() + defer oauthCancelMu.Unlock() + if oauthCancelFn == nil { + return nil, errors.New("no oauth authentication is in progress") + } + oauthCancelFn() + oauthCancelFn = nil + return nil, nil +} + +func init() { + rc.Add(rc.Call{ + Path: "config/oauthstatus", + Fn: rcOAuthStatus, + Title: "Get the status of the OAuth authentication server.", + Help: `Returns the current status of the OAuth authentication server. + +Returns a JSON object: +- status - "running" or "stopped" + +Eg + + { + "status": "running" + } +`, + }) +} + +func rcOAuthStatus(ctx context.Context, in rc.Params) (out rc.Params, err error) { + oauthCancelMu.Lock() + defer oauthCancelMu.Unlock() + status := "stopped" + if oauthCancelFn != nil { + status = "running" + } + return rc.Params{"status": status}, nil +} + // Return true if can run without a webserver and just entering a code func noWebserverNeeded(oauthConfig *Config) bool { return oauthConfig.RedirectURL == TitleBarRedirectURL @@ -854,8 +914,19 @@ func configSetup(ctx context.Context, id, name string, m configmap.Mapper, oauth if err != nil { return "", fmt.Errorf("failed to start auth webserver: %w", err) } + oauthCtx, cancel := context.WithCancel(ctx) + oauthCancelMu.Lock() + oauthCancelFn = cancel + oauthCancelMu.Unlock() + go server.Serve() - defer server.Stop() + defer func() { + oauthCancelMu.Lock() + oauthCancelFn = nil + oauthCancelMu.Unlock() + cancel() + server.Stop() + }() authURL = "http://" + bindAddress + "/auth?state=" + state if !authorizeNoAutoBrowser { @@ -873,7 +944,12 @@ func configSetup(ctx context.Context, id, name string, m configmap.Mapper, oauth // Read the code via the webserver fs.Logf(nil, "Waiting for code...\n") - auth := <-server.result + var auth *AuthResult + select { + case auth = <-server.result: + case <-oauthCtx.Done(): + return "", errors.New("oauth authentication was cancelled") + } if !auth.OK || auth.Code == "" { return "", auth } diff --git a/lib/oauthutil/rc_test.go b/lib/oauthutil/rc_test.go new file mode 100644 index 000000000..5f1aadae9 --- /dev/null +++ b/lib/oauthutil/rc_test.go @@ -0,0 +1,65 @@ +package oauthutil + +import ( + "context" + "testing" + + "github.com/rclone/rclone/fs/rc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRcOAuthStatus(t *testing.T) { + call := rc.Calls.Get("config/oauthstatus") + require.NotNil(t, call) + ctx := context.Background() + + // Status should be "stopped" when no OAuth is running + out, err := call.Fn(ctx, rc.Params{}) + require.NoError(t, err) + assert.Equal(t, "stopped", out["status"]) + + // Simulate an active OAuth flow + ctx2, cancel := context.WithCancel(ctx) + defer cancel() + oauthCancelMu.Lock() + oauthCancelFn = cancel + oauthCancelMu.Unlock() + defer func() { + oauthCancelMu.Lock() + oauthCancelFn = nil + oauthCancelMu.Unlock() + }() + + // Status should be "running" + out, err = call.Fn(ctx2, rc.Params{}) + require.NoError(t, err) + assert.Equal(t, "running", out["status"]) +} + +func TestRcOAuthStop(t *testing.T) { + call := rc.Calls.Get("config/oauthstop") + require.NotNil(t, call) + ctx := context.Background() + + // Stop should return error when no OAuth is running + _, err := call.Fn(ctx, rc.Params{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no oauth authentication is in progress") + + // Simulate an active OAuth flow + _, cancel := context.WithCancel(ctx) + oauthCancelMu.Lock() + oauthCancelFn = cancel + oauthCancelMu.Unlock() + + // Stop should succeed + out, err := call.Fn(ctx, rc.Params{}) + require.NoError(t, err) + assert.Nil(t, out) + + // Subsequent stop should return error + _, err = call.Fn(ctx, rc.Params{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no oauth authentication is in progress") +}