Skip to content

Commit d2178de

Browse files
authored
[azclient] Refactor AuthProvider multi-tenant token credential (#8596)
* Refactor AuthProvider multi-tenant token credential * Add unit tests * Fix license header * Fix lint issues
1 parent a1b5cf0 commit d2178de

File tree

9 files changed

+1820
-142
lines changed

9 files changed

+1820
-142
lines changed

pkg/azclient/auth.go

Lines changed: 38 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -17,156 +17,59 @@ limitations under the License.
1717
package azclient
1818

1919
import (
20-
"context"
20+
"errors"
2121
"fmt"
22-
"os"
2322
"strings"
2423

2524
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
25+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
2626
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
27-
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
28-
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
29-
"github.com/Azure/msi-dataplane/pkg/dataplane"
27+
)
3028

31-
"sigs.k8s.io/cloud-provider-azure/pkg/azclient/armauth"
29+
var (
30+
ErrNoValidAuthMethodFound = errors.New("no valid authentication method found")
3231
)
3332

3433
type AuthProvider struct {
35-
ComputeCredential azcore.TokenCredential
36-
NetworkCredential azcore.TokenCredential
37-
MultiTenantCredential azcore.TokenCredential
38-
CloudConfig cloud.Configuration
34+
ComputeCredential azcore.TokenCredential
35+
AdditionalComputeClientOptions []func(option *arm.ClientOptions)
36+
NetworkCredential azcore.TokenCredential
37+
CloudConfig cloud.Configuration
3938
}
4039

41-
func NewAuthProvider(armConfig *ARMClientConfig, config *AzureAuthConfig, clientOptionsMutFn ...func(option *policy.ClientOptions)) (*AuthProvider, error) {
40+
func NewAuthProvider(
41+
armConfig *ARMClientConfig,
42+
config *AzureAuthConfig,
43+
options ...AuthProviderOption,
44+
) (*AuthProvider, error) {
45+
opts := defaultAuthProviderOptions()
46+
for _, opt := range options {
47+
opt(opts)
48+
}
49+
4250
clientOption, _, err := GetAzCoreClientOption(armConfig)
4351
if err != nil {
4452
return nil, err
4553
}
46-
for _, fn := range clientOptionsMutFn {
54+
for _, fn := range opts.ClientOptionsMutFn {
4755
fn(clientOption)
4856
}
49-
var computeCredential azcore.TokenCredential
50-
var networkTokenCredential azcore.TokenCredential
51-
var multiTenantCredential azcore.TokenCredential
52-
53-
// federatedIdentityCredential is used for workload identity federation
54-
if aadFederatedTokenFile, enabled := config.GetAzureFederatedTokenFile(); enabled {
55-
computeCredential, err = azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
56-
ClientOptions: *clientOption,
57-
ClientID: config.GetAADClientID(),
58-
TenantID: armConfig.GetTenantID(),
59-
TokenFilePath: aadFederatedTokenFile,
60-
})
61-
if err != nil {
62-
return nil, err
63-
}
64-
}
65-
// managedIdentityCredential is used for managed identity extension
66-
if computeCredential == nil && config.UseManagedIdentityExtension {
67-
credOptions := &azidentity.ManagedIdentityCredentialOptions{
68-
ClientOptions: *clientOption,
69-
}
70-
if len(config.UserAssignedIdentityID) > 0 {
71-
if strings.Contains(strings.ToUpper(config.UserAssignedIdentityID), "/SUBSCRIPTIONS/") {
72-
credOptions.ID = azidentity.ResourceID(config.UserAssignedIdentityID)
73-
} else {
74-
credOptions.ID = azidentity.ClientID(config.UserAssignedIdentityID)
75-
}
76-
}
77-
computeCredential, err = azidentity.NewManagedIdentityCredential(credOptions)
78-
if err != nil {
79-
return nil, err
80-
}
81-
if config.AuxiliaryTokenProvider != nil && IsMultiTenant(armConfig) {
82-
networkTokenCredential, err = armauth.NewKeyVaultCredential(
83-
computeCredential,
84-
config.AuxiliaryTokenProvider.SecretResourceID(),
85-
)
86-
if err != nil {
87-
return nil, fmt.Errorf("create KeyVaultCredential for auxiliary token provider: %w", err)
88-
}
89-
}
90-
}
91-
92-
// Client secret authentication
93-
if computeCredential == nil && len(config.GetAADClientSecret()) > 0 {
94-
credOptions := &azidentity.ClientSecretCredentialOptions{
95-
ClientOptions: *clientOption,
96-
}
97-
computeCredential, err = azidentity.NewClientSecretCredential(armConfig.GetTenantID(), config.GetAADClientID(), config.GetAADClientSecret(), credOptions)
98-
if err != nil {
99-
return nil, err
100-
}
101-
if IsMultiTenant(armConfig) {
102-
credOptions := &azidentity.ClientSecretCredentialOptions{
103-
ClientOptions: *clientOption,
104-
}
105-
networkTokenCredential, err = azidentity.NewClientSecretCredential(armConfig.NetworkResourceTenantID, config.GetAADClientID(), config.AADClientSecret, credOptions)
106-
if err != nil {
107-
return nil, err
108-
}
109-
110-
credOptions = &azidentity.ClientSecretCredentialOptions{
111-
ClientOptions: *clientOption,
112-
AdditionallyAllowedTenants: []string{armConfig.NetworkResourceTenantID},
113-
}
114-
multiTenantCredential, err = azidentity.NewClientSecretCredential(armConfig.GetTenantID(), config.GetAADClientID(), config.GetAADClientSecret(), credOptions)
115-
if err != nil {
116-
return nil, err
117-
}
118-
119-
}
120-
}
12157

122-
// ClientCertificateCredential is used for client certificate
123-
if computeCredential == nil && len(config.AADClientCertPath) > 0 {
124-
credOptions := &azidentity.ClientCertificateCredentialOptions{
125-
ClientOptions: *clientOption,
126-
SendCertificateChain: true,
127-
}
128-
certData, err := os.ReadFile(config.AADClientCertPath)
129-
if err != nil {
130-
return nil, fmt.Errorf("reading the client certificate from file %s: %w", config.AADClientCertPath, err)
131-
}
132-
certificate, privateKey, err := azidentity.ParseCertificates(certData, []byte(config.AADClientCertPassword))
133-
if err != nil {
134-
return nil, fmt.Errorf("decoding the client certificate: %w", err)
135-
}
136-
computeCredential, err = azidentity.NewClientCertificateCredential(armConfig.GetTenantID(), config.GetAADClientID(), certificate, privateKey, credOptions)
137-
if err != nil {
138-
return nil, err
139-
}
140-
if IsMultiTenant(armConfig) {
141-
networkTokenCredential, err = azidentity.NewClientCertificateCredential(armConfig.NetworkResourceTenantID, config.GetAADClientID(), certificate, privateKey, credOptions)
142-
if err != nil {
143-
return nil, err
144-
}
145-
credOptions = &azidentity.ClientCertificateCredentialOptions{
146-
ClientOptions: *clientOption,
147-
AdditionallyAllowedTenants: []string{armConfig.NetworkResourceTenantID},
148-
}
149-
multiTenantCredential, err = azidentity.NewClientCertificateCredential(armConfig.GetTenantID(), config.GetAADClientID(), certificate, privateKey, credOptions)
150-
if err != nil {
151-
return nil, err
152-
}
153-
}
58+
aadFederatedTokenFile, federatedTokenEnabled := config.GetAzureFederatedTokenFile()
59+
switch {
60+
case federatedTokenEnabled:
61+
return newAuthProviderWithWorkloadIdentity(aadFederatedTokenFile, armConfig, config, clientOption, opts)
62+
case config.UseManagedIdentityExtension:
63+
return newAuthProviderWithManagedIdentity(armConfig, config, clientOption, opts)
64+
case len(config.GetAADClientSecret()) > 0:
65+
return newAuthProviderWithServicePrincipalClientSecret(armConfig, config, clientOption, opts)
66+
case len(config.AADClientCertPath) > 0:
67+
return newAuthProviderWithServicePrincipalClientCertificate(armConfig, config, clientOption, opts)
68+
case len(config.AADMSIDataPlaneIdentityPath) > 0:
69+
return newAuthProviderWithUserAssignedIdentity(config, clientOption, opts)
70+
default:
71+
return nil, ErrNoValidAuthMethodFound
15472
}
155-
156-
// UserAssignedIdentityCredentials authentication
157-
if computeCredential == nil && len(config.AADMSIDataPlaneIdentityPath) > 0 {
158-
computeCredential, err = dataplane.NewUserAssignedIdentityCredential(context.Background(), config.AADMSIDataPlaneIdentityPath, dataplane.WithClientOpts(azcore.ClientOptions{Cloud: clientOption.Cloud}))
159-
if err != nil {
160-
return nil, err
161-
}
162-
}
163-
164-
return &AuthProvider{
165-
ComputeCredential: computeCredential,
166-
NetworkCredential: networkTokenCredential,
167-
MultiTenantCredential: multiTenantCredential,
168-
CloudConfig: clientOption.Cloud,
169-
}, nil
17073
}
17174

17275
func (factory *AuthProvider) GetAzIdentity() azcore.TokenCredential {
@@ -180,18 +83,11 @@ func (factory *AuthProvider) GetNetworkAzIdentity() azcore.TokenCredential {
18083
return factory.ComputeCredential
18184
}
18285

183-
func (factory *AuthProvider) GetMultiTenantIdentity() azcore.TokenCredential {
184-
if factory.MultiTenantCredential != nil {
185-
return factory.MultiTenantCredential
186-
}
187-
return factory.ComputeCredential
188-
}
189-
190-
func (factory *AuthProvider) IsMultiTenantModeEnabled() bool {
191-
return factory.MultiTenantCredential != nil
86+
func (factory *AuthProvider) DefaultTokenScope() string {
87+
return DefaultTokenScopeFor(factory.CloudConfig)
19288
}
19389

194-
func (factory *AuthProvider) DefaultTokenScope() string {
195-
audience := factory.CloudConfig.Services[cloud.ResourceManager].Audience
90+
func DefaultTokenScopeFor(cloudCfg cloud.Configuration) string {
91+
audience := cloudCfg.Services[cloud.ResourceManager].Audience
19692
return fmt.Sprintf("%s/.default", strings.TrimRight(audience, "/"))
19793
}

pkg/azclient/auth_fake_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package azclient
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"sync/atomic"
23+
"testing"
24+
"time"
25+
26+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
27+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
28+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
29+
"github.com/stretchr/testify/assert"
30+
)
31+
32+
var (
33+
incCounter = atomic.Int64{}
34+
)
35+
36+
type fakeTokenCredential struct {
37+
ID string
38+
}
39+
40+
func newFakeTokenCredential() *fakeTokenCredential {
41+
id := fmt.Sprintf("fake-token-credential-%d-%d", incCounter.Add(1), time.Now().UnixNano())
42+
return &fakeTokenCredential{ID: id}
43+
}
44+
45+
func (f *fakeTokenCredential) GetToken(
46+
_ context.Context,
47+
_ policy.TokenRequestOptions,
48+
) (azcore.AccessToken, error) {
49+
panic("not implemented")
50+
}
51+
52+
type AuthProviderAssertions func(t testing.TB, authProvider *AuthProvider)
53+
54+
func ApplyAssertions(t testing.TB, authProvider *AuthProvider, assertions []AuthProviderAssertions) {
55+
t.Helper()
56+
57+
for _, assertion := range assertions {
58+
assertion(t, authProvider)
59+
}
60+
}
61+
62+
func AssertComputeTokenCredential(tokenCredential *fakeTokenCredential) AuthProviderAssertions {
63+
return func(t testing.TB, authProvider *AuthProvider) {
64+
t.Helper()
65+
66+
assert.NotNil(t, authProvider.ComputeCredential)
67+
68+
cred, ok := authProvider.ComputeCredential.(*fakeTokenCredential)
69+
assert.True(t, ok, "expected a fake token credential")
70+
assert.Equal(t, tokenCredential.ID, cred.ID)
71+
}
72+
}
73+
74+
func AssertNetworkTokenCredential(tokenCredential *fakeTokenCredential) AuthProviderAssertions {
75+
return func(t testing.TB, authProvider *AuthProvider) {
76+
t.Helper()
77+
78+
assert.NotNil(t, authProvider.NetworkCredential)
79+
80+
cred, ok := authProvider.NetworkCredential.(*fakeTokenCredential)
81+
assert.True(t, ok, "expected a fake token credential")
82+
assert.Equal(t, tokenCredential.ID, cred.ID)
83+
}
84+
}
85+
86+
func AssertNilNetworkTokenCredential() AuthProviderAssertions {
87+
return func(t testing.TB, authProvider *AuthProvider) {
88+
t.Helper()
89+
90+
assert.Nil(t, authProvider.NetworkCredential)
91+
}
92+
}
93+
94+
func AssertEmptyAdditionalComputeClientOptions() AuthProviderAssertions {
95+
return func(t testing.TB, authProvider *AuthProvider) {
96+
t.Helper()
97+
98+
assert.Empty(t, authProvider.AdditionalComputeClientOptions)
99+
}
100+
}
101+
102+
func AssertCloudConfig(expected cloud.Configuration) AuthProviderAssertions {
103+
return func(t testing.TB, authProvider *AuthProvider) {
104+
t.Helper()
105+
106+
assert.Equal(t, expected, authProvider.CloudConfig)
107+
}
108+
}

0 commit comments

Comments
 (0)