diff --git a/repo/grpc_repository_client.go b/repo/grpc_repository_client.go index ba568ac2f..a5cc326f4 100644 --- a/repo/grpc_repository_client.go +++ b/repo/grpc_repository_client.go @@ -829,18 +829,9 @@ func openGRPCAPIRepository(ctx context.Context, si *APIServerInfo, password stri transportCreds = credentials.NewClientTLSFromCert(nil, "") } - u, err := url.Parse(si.BaseURL) + uri, err := baseURLToURI(si.BaseURL) if err != nil { - return nil, errors.Wrap(err, "unable to parse server URL") - } - - if u.Scheme != "kopia" && u.Scheme != "https" && u.Scheme != "unix+https" { - return nil, errors.Errorf("invalid server address, must be 'https://host:port' or 'unix+https://") - } - - uri := net.JoinHostPort(u.Hostname(), u.Port()) - if u.Scheme == "unix+https" { - uri = "unix:" + u.Path + return nil, errors.Wrap(err, "parsing base URL") } conn, err := grpc.NewClient( @@ -869,6 +860,24 @@ func(ctx context.Context) error { return rep, nil } +func baseURLToURI(baseURL string) (uri string, err error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", errors.Wrap(err, "unable to parse server URL") + } + + if u.Scheme != "kopia" && u.Scheme != "https" && u.Scheme != "unix+https" { + return "", errors.Errorf("invalid server address, must be 'https://host:port' or 'unix+https://") + } + + uri = net.JoinHostPort(u.Hostname(), u.Port()) + if u.Scheme == "unix+https" { + uri = "unix:" + u.Path + } + + return uri, nil +} + func (r *grpcRepositoryClient) getOrEstablishInnerSession(ctx context.Context) (*grpcInnerSession, error) { r.innerSessionMutex.Lock() defer r.innerSessionMutex.Unlock() diff --git a/repo/grpc_repository_client_unit_test.go b/repo/grpc_repository_client_unit_test.go new file mode 100644 index 000000000..029f804d2 --- /dev/null +++ b/repo/grpc_repository_client_unit_test.go @@ -0,0 +1,64 @@ +package repo + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBaseURLToURI(t *testing.T) { + for _, tc := range []struct { + name string + baseURL string + expURI string + expErrMsg string + }{ + { + name: "ipv4", + baseURL: "https://1.2.3.4:5678", + expURI: "1.2.3.4:5678", + expErrMsg: "", + }, + { + name: "ipv6", + baseURL: "https://[2600:1f14:253f:ef00:87b9::10]:51515", + expURI: "[2600:1f14:253f:ef00:87b9::10]:51515", + expErrMsg: "", + }, + { + name: "unix https scheme", + baseURL: "unix+https:///tmp/kopia-test606141450/sock", + expURI: "unix:/tmp/kopia-test606141450/sock", + expErrMsg: "", + }, + { + name: "kopia scheme", + baseURL: "kopia://a:0", + expURI: "a:0", + expErrMsg: "", + }, + { + name: "unix http scheme is invalid", + baseURL: "unix+http:///tmp/kopia-test606141450/sock", + expURI: "", + expErrMsg: "invalid server address", + }, + { + name: "invalid address", + baseURL: "a", + expURI: "", + expErrMsg: "invalid server address", + }, + } { + t.Run(tc.name, func(t *testing.T) { + gotURI, err := baseURLToURI(tc.baseURL) + if tc.expErrMsg != "" { + require.ErrorContains(t, err, tc.expErrMsg) + return + } + + require.NoError(t, err) + require.Equal(t, tc.expURI, gotURI) + }) + } +}