From 2d9b7f1e336b812838918014d87e2cb404cbff23 Mon Sep 17 00:00:00 2001 From: Julio Lopez <1953782+julio-lopez@users.noreply.github.com> Date: Tue, 29 Apr 2025 23:47:41 -0700 Subject: [PATCH] feat(providers): Azure Blob client certificate authentication (#4535) Allow the use of a client certificate when authenticating to an Azure Blob storage provider. Tests included. Credit: @DeepikaDixit Authored-by: Deepika Dixit --- cli/storage_azure.go | 1 + repo/blob/azure/azure_options.go | 3 + repo/blob/azure/azure_storage.go | 86 ++++++++++++++++----------- repo/blob/azure/azure_storage_test.go | 39 ++++++++++++ 4 files changed, 94 insertions(+), 35 deletions(-) diff --git a/cli/storage_azure.go b/cli/storage_azure.go index b5f0803b8..64b9ba7a3 100644 --- a/cli/storage_azure.go +++ b/cli/storage_azure.go @@ -25,6 +25,7 @@ func (c *storageAzureFlags) Setup(svc StorageProviderServices, cmd *kingpin.CmdC cmd.Flag("tenant-id", "Azure service principle tenant ID (overrides AZURE_TENANT_ID environment variable)").Envar(svc.EnvName("AZURE_TENANT_ID")).StringVar(&c.azOptions.TenantID) cmd.Flag("client-id", "Azure service principle client ID (overrides AZURE_CLIENT_ID environment variable)").Envar(svc.EnvName("AZURE_CLIENT_ID")).StringVar(&c.azOptions.ClientID) cmd.Flag("client-secret", "Azure service principle client secret (overrides AZURE_CLIENT_SECRET environment variable)").Envar(svc.EnvName("AZURE_CLIENT_SECRET")).StringVar(&c.azOptions.ClientSecret) + cmd.Flag("client-cert", "Azure client certificate (overrides AZURE_CLIENT_CERT environment variable)").Envar(svc.EnvName("AZURE_CLIENT_CERT")).StringVar(&c.azOptions.ClientCert) commonThrottlingFlags(cmd, &c.azOptions.Limits) diff --git a/repo/blob/azure/azure_options.go b/repo/blob/azure/azure_options.go index 80d11a504..34f733556 100644 --- a/repo/blob/azure/azure_options.go +++ b/repo/blob/azure/azure_options.go @@ -28,6 +28,9 @@ type Options struct { ClientID string ClientSecret string + // ClientCert are used for creating ClientCertificateCredentials + ClientCert string + StorageDomain string `json:"storageDomain,omitempty"` throttling.Limits diff --git a/repo/blob/azure/azure_storage.go b/repo/blob/azure/azure_storage.go index 4c4379acf..85599a2e3 100644 --- a/repo/blob/azure/azure_storage.go +++ b/repo/blob/azure/azure_storage.go @@ -365,11 +365,6 @@ func New(ctx context.Context, opt *Options, isCreate bool) (blob.Storage, error) return nil, errors.New("container name must be specified") } - var ( - service *azblob.Client - serviceErr error - ) - storageDomain := opt.StorageDomain if storageDomain == "" { storageDomain = "blob.core.windows.net" @@ -377,36 +372,7 @@ func New(ctx context.Context, opt *Options, isCreate bool) (blob.Storage, error) storageHostname := fmt.Sprintf("%v.%v", opt.StorageAccount, storageDomain) - switch { - // shared access signature - case opt.SASToken != "": - service, serviceErr = azblob.NewClientWithNoCredential( - fmt.Sprintf("https://%s?%s", storageHostname, opt.SASToken), nil) - - // storage account access key - case opt.StorageKey != "": - // create a credentials object. - cred, err := azblob.NewSharedKeyCredential(opt.StorageAccount, opt.StorageKey) - if err != nil { - return nil, errors.Wrap(err, "unable to initialize storage access key credentials") - } - - service, serviceErr = azblob.NewClientWithSharedKeyCredential( - fmt.Sprintf("https://%s/", storageHostname), cred, nil, - ) - // client secret - case opt.TenantID != "" && opt.ClientID != "" && opt.ClientSecret != "": - cred, err := azidentity.NewClientSecretCredential(opt.TenantID, opt.ClientID, opt.ClientSecret, nil) - if err != nil { - return nil, errors.Wrap(err, "unable to initialize client secret credential") - } - - service, serviceErr = azblob.NewClient(fmt.Sprintf("https://%s/", storageHostname), cred, nil) - - default: - return nil, errors.New("one of the storage key, SAS token or client secret must be provided") - } - + service, serviceErr := getAZService(opt, storageHostname) if serviceErr != nil { return nil, errors.Wrap(serviceErr, "opening azure service") } @@ -437,6 +403,56 @@ func New(ctx context.Context, opt *Options, isCreate bool) (blob.Storage, error) return az, nil } +func getAZService(opt *Options, storageHostname string) (*azblob.Client, error) { + var ( + service *azblob.Client + serviceErr error + ) + + switch { + // shared access signature + case opt.SASToken != "": + service, serviceErr = azblob.NewClientWithNoCredential( + fmt.Sprintf("https://%s?%s", storageHostname, opt.SASToken), nil) + // storage account access key + case opt.StorageKey != "": + // create a credentials object. + cred, err := azblob.NewSharedKeyCredential(opt.StorageAccount, opt.StorageKey) + if err != nil { + return nil, errors.Wrap(err, "unable to initialize storage access key credentials") + } + + service, serviceErr = azblob.NewClientWithSharedKeyCredential( + fmt.Sprintf("https://%s/", storageHostname), cred, nil, + ) + // client secret + case opt.TenantID != "" && opt.ClientID != "" && opt.ClientSecret != "": + cred, err := azidentity.NewClientSecretCredential(opt.TenantID, opt.ClientID, opt.ClientSecret, nil) + if err != nil { + return nil, errors.Wrap(err, "unable to initialize client secret credential") + } + + service, serviceErr = azblob.NewClient(fmt.Sprintf("https://%s/", storageHostname), cred, nil) + // client certificate + case opt.TenantID != "" && opt.ClientID != "" && opt.ClientCert != "": + certs, key, certErr := azidentity.ParseCertificates([]byte(opt.ClientCert), nil) + if certErr != nil { + return nil, errors.Wrap(certErr, "failed to read client cert") + } + + cred, credErr := azidentity.NewClientCertificateCredential(opt.TenantID, opt.ClientID, certs, key, nil) + if credErr != nil { + return nil, errors.Wrap(credErr, "unable to initialize client cert credential") + } + + service, serviceErr = azblob.NewClient(fmt.Sprintf("https://%s/", storageHostname), cred, nil) + default: + return nil, errors.New("one of the storage key, SAS token, client secret or client certificate must be provided") + } + + return service, errors.Wrap(serviceErr, "unable to create azure client") +} + func init() { blob.AddSupportedStorage(azStorageType, Options{}, New) } diff --git a/repo/blob/azure/azure_storage_test.go b/repo/blob/azure/azure_storage_test.go index c12909c6c..47d49778c 100644 --- a/repo/blob/azure/azure_storage_test.go +++ b/repo/blob/azure/azure_storage_test.go @@ -33,6 +33,7 @@ testStorageTenantIDEnv = "KOPIA_AZURE_TEST_TENANT_ID" testStorageClientIDEnv = "KOPIA_AZURE_TEST_CLIENT_ID" testStorageClientSecretEnv = "KOPIA_AZURE_TEST_CLIENT_SECRET" + testStorageClientCertEnv = "KOPIA_AZURE_TEST_CLIENT_CERT" ) func getEnvOrSkip(t *testing.T, name string) string { @@ -201,6 +202,44 @@ func TestAzureStorageClientSecret(t *testing.T) { require.NoError(t, providervalidation.ValidateProvider(ctx, st, blobtesting.TestValidationOptions)) } +func TestAzureStorageClientCertificate(t *testing.T) { + t.Parallel() + testutil.ProviderTest(t) + + container := getEnvOrSkip(t, testContainerEnv) + storageAccount := getEnvOrSkip(t, testStorageAccountEnv) + tenantID := getEnvOrSkip(t, testStorageTenantIDEnv) + clientID := getEnvOrSkip(t, testStorageClientIDEnv) + clientCert := getEnvOrSkip(t, testStorageClientCertEnv) + + data := make([]byte, 8) + rand.Read(data) + + ctx := testlogging.Context(t) + + // use context that gets canceled after storage is initialize, + // to verify we do not depend on the original context past initialization. + newctx, cancel := context.WithCancel(ctx) + st, err := azure.New(newctx, &azure.Options{ + Container: container, + StorageAccount: storageAccount, + TenantID: tenantID, + ClientID: clientID, + ClientCert: clientCert, + Prefix: fmt.Sprintf("sastest-%v-%x/", clock.Now().Unix(), data), + }, false) + + require.NoError(t, err) + cancel() + + defer st.Close(ctx) + defer blobtesting.CleanupOldData(ctx, t, st, 0) + + blobtesting.VerifyStorage(ctx, t, st, blob.PutOptions{}) + blobtesting.AssertConnectionInfoRoundTrips(ctx, t, st) + require.NoError(t, providervalidation.ValidateProvider(ctx, st, blobtesting.TestValidationOptions)) +} + func TestAzureStorageInvalidBlob(t *testing.T) { testutil.ProviderTest(t)