diff --git a/cmd/kopia/config.go b/cmd/kopia/config.go index 59b45667a..5a91cc8e7 100644 --- a/cmd/kopia/config.go +++ b/cmd/kopia/config.go @@ -2,7 +2,6 @@ import ( "encoding/hex" - "errors" "fmt" "io/ioutil" "log" @@ -112,8 +111,6 @@ func openVaultSpecifiedByFlag() (*vault.Vault, error) { return vault.Open(storage, creds) } -var errPasswordTooShort = errors.New("password too short") - func getVaultCredentials(isNew bool) (vault.Credentials, error) { if *key != "" { k, err := hex.DecodeString(*key) @@ -148,11 +145,6 @@ func getVaultCredentials(isNew bool) (vault.Credentials, error) { if isNew { for { p1, err := askPass("Enter password to create new vault: ") - if err == errPasswordTooShort { - fmt.Printf("Password too short, must be at least %v characters, you entered %v. Try again.", vault.MinPasswordLength, len(p1)) - fmt.Println() - continue - } if err != nil { return nil, err } @@ -177,16 +169,23 @@ func getVaultCredentials(isNew bool) (vault.Credentials, error) { } func askPass(prompt string) (string, error) { - b, err := speakeasy.Ask(prompt) - if err != nil { - return "", err + for { + b, err := speakeasy.Ask(prompt) + if err != nil { + return "", err + } + + p := string(b) + + if len(p) == 0 { + continue + } + + if len(p) >= vault.MinPasswordLength { + return p, nil + } + + fmt.Printf("Password too short, must be at least %v characters, you entered %v. Try again.", vault.MinPasswordLength, len(p)) + fmt.Println() } - - p := string(b) - - if len(p) < vault.MinPasswordLength { - return p, errPasswordTooShort - } - - return p, nil } diff --git a/storage/gcs/gcs.go b/storage/gcs/gcs.go index b9a17e987..946352e62 100644 --- a/storage/gcs/gcs.go +++ b/storage/gcs/gcs.go @@ -7,9 +7,13 @@ "io/ioutil" "log" "net/http" + "net/http/httptest" "os" + "runtime" "time" + "github.com/skratchdot/open-golang/open" + "github.com/kopia/kopia/storage" "golang.org/x/net/context" @@ -296,6 +300,61 @@ func writeTokenToFile(filePath string, token *oauth2.Token) error { } func tokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { + if runtime.GOOS == "windows" { + return tokenFromWebLocalServer(ctx, config) + } + + // On non-SSH Unix, that has X11 configured use local web server. + if os.Getenv("DISPLAY") != "" && os.Getenv("SSH_CLIENT") == "" { + return tokenFromWebLocalServer(ctx, config) + } + + // Otherwise fall back to asking user to manually copy/paste the code. + return tokenFromWebManual(ctx, config) +} + +func tokenFromWebLocalServer(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { + ch := make(chan string) + randState := fmt.Sprintf("st%d", time.Now().UnixNano()) + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path == "/favicon.ico" { + http.Error(rw, "", 404) + return + } + if req.FormValue("state") != randState { + log.Printf("State doesn't match: req = %#v", req) + http.Error(rw, "", 500) + return + } + if code := req.FormValue("code"); code != "" { + fmt.Fprintf(rw, "

Success

Authorized.") + rw.(http.Flusher).Flush() + ch <- code + return + } + log.Printf("no code") + http.Error(rw, "", 500) + })) + defer ts.Close() + + config.RedirectURL = ts.URL + authURL := config.AuthCodeURL(randState) + go open.Start(authURL) + fmt.Println("Opening URL in web browser to get OAuth2 authorization token:") + fmt.Println() + fmt.Println(" ", authURL) + fmt.Println() + code := <-ch + + token, err := config.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("token exchange error: %v", err) + } + + return token, nil +} + +func tokenFromWebManual(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { config.RedirectURL = "urn:ietf:wg:oauth:2.0:oob" authURL := config.AuthCodeURL("") var code string