Skip to content

Commit 79fcca1

Browse files
author
Jan Roehrich
committed
Fix race condition between service account availability and webhook invocation
1 parent ba509d3 commit 79fcca1

File tree

5 files changed

+90
-27
lines changed

5 files changed

+90
-27
lines changed

main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ func main() {
8686

8787
debug := flag.Bool("enable-debugging-handlers", false, "Enable debugging handlers. Currently /debug/alpha/cache is supported")
8888

89+
saLookupGracePeriod := flag.Duration("service-account-lookup-grace-period", 100*time.Millisecond, "The grace period for service account to be available in cache before not mutating a pod. Defaults to 100ms. Set to 0 to deactivate waiting. Carefully use higher values as it may have significant impact on Kubernetes' pod scheduling performance.")
90+
8991
klog.InitFlags(goflag.CommandLine)
9092
// Add klog CommandLine flags to pflag CommandLine
9193
goflag.CommandLine.VisitAll(func(f *goflag.Flag) {
@@ -208,6 +210,7 @@ func main() {
208210
handler.WithServiceAccountCache(saCache),
209211
handler.WithContainerCredentialsConfig(containerCredentialsConfig),
210212
handler.WithRegion(*region),
213+
handler.WithSALookupGraceTime(*saLookupGracePeriod),
211214
)
212215

213216
addr := fmt.Sprintf(":%d", *port)

pkg/cache/cache.go

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ type CacheResponse struct {
4242

4343
type ServiceAccountCache interface {
4444
Start(stop chan struct{})
45-
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64)
45+
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
46+
GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool)
4647
GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64)
4748
// ToJSON returns cache contents as JSON string
4849
ToJSON() string
@@ -60,6 +61,7 @@ type serviceAccountCache struct {
6061
composeRoleArn ComposeRoleArn
6162
defaultTokenExpiration int64
6263
webhookUsage prometheus.Gauge
64+
notificationHandlers map[string]chan any // type of channel doesn't matter. It's just for being notified
6365
}
6466

6567
type ComposeRoleArn struct {
@@ -87,41 +89,51 @@ func init() {
8789
// Get will return the cached configuration of the given ServiceAccount.
8890
// It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any
8991
// ServiceAccount configured through the pod-identity-webhook ConfigMap.
90-
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
92+
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
93+
return c.GetOrNotify(name, namespace, nil)
94+
}
95+
96+
// GetOrNotify will return the cached configuration of the given ServiceAccount.
97+
// It will first look at the set of ServiceAccounts configured using annotations. If none is found, it will register
98+
// handler to be notified as soon as a ServiceAccount with given key is populated to the cache. Afterwards it will check
99+
// for a ServiceAccount configured through the pod-identity-webhook ConfigMap.
100+
func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
91101
klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
92102
{
93-
resp := c.getSA(name, namespace)
103+
resp := c.getSAorNotify(name, namespace, handler)
94104
if resp != nil && resp.RoleARN != "" {
95-
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
105+
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
96106
}
97107
}
98108
{
99109
resp := c.getCM(name, namespace)
100110
if resp != nil {
101-
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
111+
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
102112
}
103113
}
104114
klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
105-
return "", "", false, pkg.DefaultTokenExpiration
115+
return "", "", false, pkg.DefaultTokenExpiration, false
106116
}
107117

108118
// GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials).
109119
// The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
110120
// Use these fields if they are set in the sa annotations or config map.
111121
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
112-
if resp := c.getSA(name, namespace); resp != nil {
122+
if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
113123
return resp.UseRegionalSTS, resp.TokenExpiration
114124
} else if resp := c.getCM(name, namespace); resp != nil {
115125
return resp.UseRegionalSTS, resp.TokenExpiration
116126
}
117127
return false, pkg.DefaultTokenExpiration
118128
}
119129

120-
func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
130+
func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *CacheResponse {
121131
c.mu.RLock()
122132
defer c.mu.RUnlock()
123133
resp, ok := c.saCache[namespace+"/"+name]
124-
if !ok {
134+
if !ok && handler != nil {
135+
klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
136+
c.notificationHandlers[namespace+"/"+name] = handler
125137
return nil
126138
}
127139
return resp
@@ -212,8 +224,16 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
212224
func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) {
213225
c.mu.Lock()
214226
defer c.mu.Unlock()
215-
klog.V(5).Infof("Adding SA %s/%s to SA cache: %+v", namespace, name, resp)
227+
228+
key := namespace + "/" + name
229+
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
216230
c.saCache[namespace+"/"+name] = resp
231+
232+
if handler, found := c.notificationHandlers[key]; found {
233+
klog.V(5).Infof("Notifying handler for %q", key)
234+
handler <- 1
235+
delete(c.notificationHandlers, key)
236+
}
217237
}
218238

219239
func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) {
@@ -242,6 +262,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
242262
defaultTokenExpiration: defaultTokenExpiration,
243263
hasSynced: hasSynced,
244264
webhookUsage: webhookUsage,
265+
notificationHandlers: map[string]chan any{},
245266
}
246267

247268
saInformer.Informer().AddEventHandler(

pkg/cache/cache_test.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,18 @@ func TestSaCache(t *testing.T) {
3636
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
3737
}
3838

39-
role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default")
39+
role, aud, useRegionalSTS, tokenExpiration, found := cache.Get("default", "default")
4040

41+
assert.False(t, found, "Expected no cache entry to be found")
4142
if role != "" || aud != "" {
4243
t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration)
4344
}
4445

4546
cache.addSA(testSA)
4647

47-
role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default")
48+
role, aud, useRegionalSTS, tokenExpiration, found = cache.Get("default", "default")
4849

50+
assert.True(t, found, "Expected cache entry to be found")
4951
assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role)
5052
assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud)
5153
assert.True(t, useRegionalSTS, "Expected regional STS to be true, got false")
@@ -157,7 +159,8 @@ func TestNonRegionalSTS(t *testing.T) {
157159
t.Fatalf("cache never called addSA: %v", err)
158160
}
159161

160-
gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default")
162+
gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, found := cache.Get("default", "default")
163+
assert.True(t, found, "Expected cache entry to be found")
161164
if gotRoleArn != roleArn {
162165
t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn)
163166
}
@@ -202,7 +205,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
202205
t.Errorf("failed to build cache: %v", err)
203206
}
204207

205-
role, _, _, _ := c.Get("mysa2", "myns2")
208+
role, _, _, _, _ := c.Get("mysa2", "myns2")
206209
if role == "" {
207210
t.Errorf("cloud not find entry that should have been added")
208211
}
@@ -214,7 +217,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
214217
t.Errorf("failed to build cache: %v", err)
215218
}
216219

217-
role, _, _, _ := c.Get("mysa2", "myns2")
220+
role, _, _, _, _ := c.Get("mysa2", "myns2")
218221
if role == "" {
219222
t.Errorf("cloud not find entry that should have been added")
220223
}
@@ -226,7 +229,7 @@ func TestPopulateCacheFromCM(t *testing.T) {
226229
t.Errorf("failed to build cache: %v", err)
227230
}
228231

229-
role, _, _, _ := c.Get("mysa2", "myns2")
232+
role, _, _, _, _ := c.Get("mysa2", "myns2")
230233
if role != "" {
231234
t.Errorf("found entry that should have been removed")
232235
}
@@ -256,7 +259,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
256259
c.addSA(oldSA)
257260

258261
{
259-
gotRoleArn, _, _, _ := c.Get("default", "default")
262+
gotRoleArn, _, _, _, _ := c.Get("default", "default")
260263
if gotRoleArn != roleArn {
261264
t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn)
262265
}
@@ -268,7 +271,7 @@ func TestSAAnnotationRemoval(t *testing.T) {
268271
c.addSA(newSA)
269272

270273
{
271-
gotRoleArn, _, _, _ := c.Get("default", "default")
274+
gotRoleArn, _, _, _, _ := c.Get("default", "default")
272275
if gotRoleArn != "" {
273276
t.Errorf("got roleArn %v, expected %q", gotRoleArn, "")
274277
}
@@ -323,7 +326,7 @@ func TestCachePrecedence(t *testing.T) {
323326
t.Errorf("failed to build cache: %v", err)
324327
}
325328

326-
role, _, _, exp := c.Get("mysa2", "myns2")
329+
role, _, _, exp, _ := c.Get("mysa2", "myns2")
327330
if role == "" {
328331
t.Errorf("could not find entry that should have been added")
329332
}
@@ -340,7 +343,7 @@ func TestCachePrecedence(t *testing.T) {
340343
}
341344

342345
// Removing sa2 from CM, but SA still exists
343-
role, _, _, exp := c.Get("mysa2", "myns2")
346+
role, _, _, exp, _ := c.Get("mysa2", "myns2")
344347
if role == "" {
345348
t.Errorf("could not find entry that should still exist")
346349
}
@@ -356,7 +359,7 @@ func TestCachePrecedence(t *testing.T) {
356359
c.addSA(sa2)
357360

358361
// Neither cache should return any hits now
359-
role, _, _, _ := c.Get("myns2", "mysa2")
362+
role, _, _, _, _ := c.Get("myns2", "mysa2")
360363
if role != "" {
361364
t.Errorf("found entry that should not exist")
362365
}
@@ -370,7 +373,7 @@ func TestCachePrecedence(t *testing.T) {
370373
t.Errorf("failed to build cache: %v", err)
371374
}
372375

373-
role, _, _, exp := c.Get("mysa2", "myns2")
376+
role, _, _, exp, _ := c.Get("mysa2", "myns2")
374377
if role == "" {
375378
t.Errorf("cloud not find entry that should have been added")
376379
}
@@ -422,7 +425,7 @@ func TestRoleArnComposition(t *testing.T) {
422425

423426
var roleArn string
424427
err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) {
425-
roleArn, _, _, _ = cache.Get("default", "default")
428+
roleArn, _, _, _, _ = cache.Get("default", "default")
426429
return roleArn != "", nil
427430
})
428431
if err != nil {

pkg/cache/fake.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,25 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
4444
func (f *FakeServiceAccountCache) Start(chan struct{}) {}
4545

4646
// Get gets a service account from the cache
47-
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
47+
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
4848
f.mu.RLock()
4949
defer f.mu.RUnlock()
5050
resp, ok := f.cache[namespace+"/"+name]
5151
if !ok {
52-
return "", "", false, pkg.DefaultTokenExpiration
52+
return "", "", false, pkg.DefaultTokenExpiration, false
5353
}
54-
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
54+
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
55+
}
56+
57+
// GetOrNotify gets a service account from the cache
58+
func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) (role, aud string, useRegionalSTS bool, tokenExpiration int64, found bool) {
59+
f.mu.RLock()
60+
defer f.mu.RUnlock()
61+
resp, ok := f.cache[namespace+"/"+name]
62+
if !ok {
63+
return "", "", false, pkg.DefaultTokenExpiration, false
64+
}
65+
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, true
5566
}
5667

5768
func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {

pkg/handler/handler.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"path/filepath"
2525
"strconv"
2626
"strings"
27+
"time"
2728

2829
"github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials"
2930

@@ -77,6 +78,12 @@ func WithAnnotationDomain(domain string) ModifierOpt {
7778
return func(m *Modifier) { m.AnnotationDomain = domain }
7879
}
7980

81+
// WithSALookupGraceTime sets the grace time to wait for service accounts to appear in cache
82+
func WithSALookupGraceTime(saLookupGraceTime time.Duration) ModifierOpt {
83+
return func(m *Modifier) { m.saLookupGraceTime = saLookupGraceTime }
84+
85+
}
86+
8087
// NewModifier returns a Modifier with default values
8188
func NewModifier(opts ...ModifierOpt) *Modifier {
8289
mod := &Modifier{
@@ -101,6 +108,7 @@ type Modifier struct {
101108
ContainerCredentialsConfig containercredentials.Config
102109
volName string
103110
tokenName string
111+
saLookupGraceTime time.Duration
104112
}
105113

106114
type patchOperation struct {
@@ -425,7 +433,24 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
425433
}
426434

427435
// Use the STS WebIdentity method if set
428-
roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
436+
handler := make(chan any, 1)
437+
roleArn, audience, regionalSTS, tokenExpiration, found := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
438+
key := pod.Namespace + "/" + pod.Spec.ServiceAccountName
439+
if !found && m.saLookupGraceTime > 0 {
440+
klog.Warningf("Service account %q not found in the cache. Waiting up to %s to be notified", key, m.saLookupGraceTime)
441+
select {
442+
case <-handler:
443+
roleArn, audience, regionalSTS, tokenExpiration, found = m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
444+
if !found {
445+
klog.Warningf("Service account %q not found in the cache after being notified. Not mutating.", key)
446+
return nil
447+
}
448+
case <-time.After(m.saLookupGraceTime):
449+
klog.Warningf("Service account %q not found in the cache after %s. Not mutating.", key, m.saLookupGraceTime)
450+
return nil
451+
}
452+
}
453+
klog.V(5).Infof("Value of roleArn after after cache retrieval for service account %q: %s", key, roleArn)
429454
if roleArn != "" {
430455
tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration)
431456

0 commit comments

Comments
 (0)