diff --git a/README.md b/README.md index bb8d9e13e..24bcc519f 100644 --- a/README.md +++ b/README.md @@ -143,37 +143,38 @@ When running a container with a non-root user, you need to give the container ac ``` Usage of amazon-eks-pod-identity-webhook: - --add_dir_header If true, adds the file directory to the header - --alsologtostderr log to standard error as well as files - --annotation-prefix string The Service Account annotation to look for (default "eks.amazonaws.com") - --aws-default-region string If set, AWS_DEFAULT_REGION and AWS_REGION will be set to this value in mutated containers - --enable-debugging-handlers Enable debugging handlers. Currently /debug/alpha/cache is supported - --in-cluster Use in-cluster authentication and certificate request API (default true) - --kube-api string (out-of-cluster) The url to the API server - --kubeconfig string (out-of-cluster) Absolute path to the API server kubeconfig file - --log_backtrace_at traceLocation when logging hits line file:N, emit a stack trace (default :0) - --log_dir string If non-empty, write log files in this directory - --log_file string If non-empty, use this log file - --log_file_max_size uint Defines the maximum size a log file can grow to. Unit is megabytes. If the value is 0, the maximum file size is unlimited. 
(default 1800) - --logtostderr log to standard error instead of files (default true) - --metrics-port int Port to listen on for metrics (http) (default 9999) - --namespace string (in-cluster) The namespace name this webhook, the TLS secret, and configmap resides in (default "eks") - --port int Port to listen on (default 443) - --service-name string (in-cluster) The service name fronting this webhook (default "pod-identity-webhook") - --skip_headers If true, avoid header prefixes in the log messages - --skip_log_headers If true, avoid headers when opening log files - --stderrthreshold severity logs at or above this threshold go to stderr (default 2) - --sts-regional-endpoint false Whether to inject the AWS_STS_REGIONAL_ENDPOINTS=regional env var in mutated pods. Defaults to false. - --tls-cert string (out-of-cluster) TLS certificate file path (default "/etc/webhook/certs/tls.crt") - --tls-key string (out-of-cluster) TLS key file path (default "/etc/webhook/certs/tls.key") - --tls-secret string (in-cluster) The secret name for storing the TLS serving cert (default "pod-identity-webhook") - --token-audience string The default audience for tokens. 
Can be overridden by annotation (default "sts.amazonaws.com") - --token-expiration int The token expiration (default 86400) - --token-mount-path string The path to mount tokens (default "/var/run/secrets/eks.amazonaws.com/serviceaccount") - -v, --v Level number for the log level verbosity - --version Display the version and exit - --vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging - --watch-config-map Enables watching serviceaccounts that are configured through the pod-identity-webhook configmap instead of using annotations + --add_dir_header If true, adds the file directory to the header + --alsologtostderr log to standard error as well as files + --annotation-prefix string The Service Account annotation to look for (default "eks.amazonaws.com") + --aws-default-region string If set, AWS_DEFAULT_REGION and AWS_REGION will be set to this value in mutated containers + --enable-debugging-handlers Enable debugging handlers. Currently /debug/alpha/cache is supported + --in-cluster Use in-cluster authentication and certificate request API (default true) + --kube-api string (out-of-cluster) The url to the API server + --kubeconfig string (out-of-cluster) Absolute path to the API server kubeconfig file + --log_backtrace_at traceLocation when logging hits line file:N, emit a stack trace (default :0) + --log_dir string If non-empty, write log files in this directory + --log_file string If non-empty, use this log file + --log_file_max_size uint Defines the maximum size a log file can grow to. Unit is megabytes. If the value is 0, the maximum file size is unlimited. 
(default 1800) + --logtostderr log to standard error instead of files (default true) + --metrics-port int Port to listen on for metrics (http) (default 9999) + --namespace string (in-cluster) The namespace name this webhook, the TLS secret, and configmap resides in (default "eks") + --port int Port to listen on (default 443) + --service-name string (in-cluster) The service name fronting this webhook (default "pod-identity-webhook") + --service-account-lookup-grace-period duration The grace period to wait for a service account to become available in the cache before skipping mutation of a pod. Defaults to 0, which deactivates waiting. Use higher values with care, as they may have a significant impact on Kubernetes' pod scheduling performance. + --skip_headers If true, avoid header prefixes in the log messages + --skip_log_headers If true, avoid headers when opening log files + --stderrthreshold severity logs at or above this threshold go to stderr (default 2) + --sts-regional-endpoint false Whether to inject the AWS_STS_REGIONAL_ENDPOINTS=regional env var in mutated pods. Defaults to false. + --tls-cert string (out-of-cluster) TLS certificate file path (default "/etc/webhook/certs/tls.crt") + --tls-key string (out-of-cluster) TLS key file path (default "/etc/webhook/certs/tls.key") + --tls-secret string (in-cluster) The secret name for storing the TLS serving cert (default "pod-identity-webhook") + --token-audience string The default audience for tokens. 
Can be overridden by annotation (default "sts.amazonaws.com") + --token-expiration int The token expiration (default 86400) + --token-mount-path string The path to mount tokens (default "/var/run/secrets/eks.amazonaws.com/serviceaccount") + -v, --v Level number for the log level verbosity + --version Display the version and exit + --vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging + --watch-config-map Enables watching serviceaccounts that are configured through the pod-identity-webhook configmap instead of using annotations ``` ### AWS_DEFAULT_REGION Injection diff --git a/main.go b/main.go index 22310b4be..17294659d 100644 --- a/main.go +++ b/main.go @@ -86,6 +86,8 @@ func main() { debug := flag.Bool("enable-debugging-handlers", false, "Enable debugging handlers. Currently /debug/alpha/cache is supported") + saLookupGracePeriod := flag.Duration("service-account-lookup-grace-period", 0, "The grace period for service account to be available in cache before not mutating a pod. Defaults to 0, what deactivates waiting. 
Carefully use values higher than a bunch of milliseconds as it may have significant impact on Kubernetes' pod scheduling performance.") + klog.InitFlags(goflag.CommandLine) // Add klog CommandLine flags to pflag CommandLine goflag.CommandLine.VisitAll(func(f *goflag.Flag) { @@ -208,6 +210,7 @@ func main() { handler.WithServiceAccountCache(saCache), handler.WithContainerCredentialsConfig(containerCredentialsConfig), handler.WithRegion(*region), + handler.WithSALookupGraceTime(*saLookupGracePeriod), ) addr := fmt.Sprintf(":%d", *port) diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index e411fee8e..55b5885f7 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -33,16 +33,35 @@ import ( "k8s.io/klog/v2" ) -type CacheResponse struct { +type Entry struct { RoleARN string Audience string UseRegionalSTS bool TokenExpiration int64 } +type Request struct { + Name string + Namespace string + RequestNotification bool +} + +func (r Request) CacheKey() string { + return r.Namespace + "/" + r.Name +} + +type Response struct { + RoleARN string + Audience string + UseRegionalSTS bool + TokenExpiration int64 + FoundInCache bool + Notifier <-chan struct{} +} + type ServiceAccountCache interface { Start(stop chan struct{}) - Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) + Get(request Request) Response GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) // ToJSON returns cache contents as JSON string ToJSON() string @@ -50,8 +69,8 @@ type ServiceAccountCache interface { type serviceAccountCache struct { mu sync.RWMutex // guards cache - saCache map[string]*CacheResponse - cmCache map[string]*CacheResponse + saCache map[string]*Entry + cmCache map[string]*Entry hasSynced cache.InformerSynced clientset kubernetes.Interface annotationPrefix string @@ -60,6 +79,8 @@ type serviceAccountCache struct { composeRoleArn ComposeRoleArn defaultTokenExpiration int64 webhookUsage prometheus.Gauge + 
notificationHandlers map[string]chan struct{} + handlerMu sync.Mutex } type ComposeRoleArn struct { @@ -85,56 +106,81 @@ func init() { } // Get will return the cached configuration of the given ServiceAccount. -// It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any -// ServiceAccount configured through the pod-identity-webhook ConfigMap. -func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) { - klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name) +// It will first look at the set of ServiceAccounts configured using annotations. If none is found and a notifier is +// requested, it will register a handler to be notified as soon as a ServiceAccount with given key is populated to the +// cache. Afterward it will check for a ServiceAccount configured through the pod-identity-webhook ConfigMap. +func (c *serviceAccountCache) Get(req Request) Response { + result := Response{ + TokenExpiration: pkg.DefaultTokenExpiration, + } + klog.V(5).Infof("Fetching sa %s from cache", req.CacheKey()) { - resp := c.getSA(name, namespace) - if resp != nil && resp.RoleARN != "" { - return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration + var entry *Entry + entry, result.Notifier = c.getSA(req) + if entry != nil { + result.FoundInCache = true + } + if entry != nil && entry.RoleARN != "" { + result.RoleARN = entry.RoleARN + result.Audience = entry.Audience + result.UseRegionalSTS = entry.UseRegionalSTS + result.TokenExpiration = entry.TokenExpiration + return result } } { - resp := c.getCM(name, namespace) - if resp != nil { - return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration + entry := c.getCM(req.Name, req.Namespace) + if entry != nil { + result.FoundInCache = true + result.RoleARN = entry.RoleARN + result.Audience = entry.Audience + result.UseRegionalSTS = entry.UseRegionalSTS + 
result.TokenExpiration = entry.TokenExpiration + return result } } - klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name) - return "", "", false, pkg.DefaultTokenExpiration + klog.V(5).Infof("Service account %s not found in cache", req.CacheKey()) + return result } // GetCommonConfigurations returns the common configurations that also applies to the new mutation method(i.e Container Credentials). // The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility, // Use these fields if they are set in the sa annotations or config map. func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) { - if resp := c.getSA(name, namespace); resp != nil { - return resp.UseRegionalSTS, resp.TokenExpiration - } else if resp := c.getCM(name, namespace); resp != nil { - return resp.UseRegionalSTS, resp.TokenExpiration + if entry, _ := c.getSA(Request{Name: name, Namespace: namespace, RequestNotification: false}); entry != nil { + return entry.UseRegionalSTS, entry.TokenExpiration + } else if entry := c.getCM(name, namespace); entry != nil { + return entry.UseRegionalSTS, entry.TokenExpiration } return false, pkg.DefaultTokenExpiration } -func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse { +func (c *serviceAccountCache) getSA(req Request) (*Entry, chan struct{}) { c.mu.RLock() defer c.mu.RUnlock() - resp, ok := c.saCache[namespace+"/"+name] - if !ok { - return nil + entry, ok := c.saCache[req.CacheKey()] + if !ok && req.RequestNotification { + klog.V(5).Infof("Service Account %s not found in cache, adding notification handler", req.CacheKey()) + c.handlerMu.Lock() + defer c.handlerMu.Unlock() + notifier, found := c.notificationHandlers[req.CacheKey()] + if !found { + notifier = make(chan struct{}) + c.notificationHandlers[req.CacheKey()] = notifier + } + return nil, notifier } - return resp + return 
entry, nil } -func (c *serviceAccountCache) getCM(name, namespace string) *CacheResponse { +func (c *serviceAccountCache) getCM(name, namespace string) *Entry { c.mu.RLock() defer c.mu.RUnlock() - resp, ok := c.cmCache[namespace+"/"+name] + entry, ok := c.cmCache[namespace+"/"+name] if !ok { return nil } - return resp + return entry } func (c *serviceAccountCache) popSA(name, namespace string) { @@ -164,7 +210,7 @@ func (c *serviceAccountCache) ToJSON() string { } func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) { - resp := &CacheResponse{} + entry := &Entry{} arn, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.RoleARNAnnotation] if ok { @@ -178,49 +224,59 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) { } else if !matched { klog.Warningf("arn is invalid: %s", arn) } - resp.RoleARN = arn + entry.RoleARN = arn } - resp.Audience = c.defaultAudience + entry.Audience = c.defaultAudience if audience, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.AudienceAnnotation]; ok { - resp.Audience = audience + entry.Audience = audience } - resp.UseRegionalSTS = c.defaultRegionalSTS + entry.UseRegionalSTS = c.defaultRegionalSTS if useRegionalStr, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.UseRegionalSTSAnnotation]; ok { useRegional, err := strconv.ParseBool(useRegionalStr) if err != nil { klog.V(4).Infof("Ignoring service account %s/%s invalid value for disable-regional-sts annotation", sa.Namespace, sa.Name) } else { - resp.UseRegionalSTS = useRegional + entry.UseRegionalSTS = useRegional } } - resp.TokenExpiration = c.defaultTokenExpiration + entry.TokenExpiration = c.defaultTokenExpiration if tokenExpirationStr, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.TokenExpirationAnnotation]; ok { if tokenExpiration, err := strconv.ParseInt(tokenExpirationStr, 10, 64); err != nil { - klog.V(4).Infof("Found invalid value for token expiration, using %d seconds as default: %v", resp.TokenExpiration, err) + klog.V(4).Infof("Found invalid value for token 
expiration, using %d seconds as default: %v", entry.TokenExpiration, err) } else { - resp.TokenExpiration = pkg.ValidateMinTokenExpiration(tokenExpiration) + entry.TokenExpiration = pkg.ValidateMinTokenExpiration(tokenExpiration) } } c.webhookUsage.Set(1) - c.setSA(sa.Name, sa.Namespace, resp) + c.setSA(sa.Name, sa.Namespace, entry) } -func (c *serviceAccountCache) setSA(name, namespace string, resp *CacheResponse) { +func (c *serviceAccountCache) setSA(name, namespace string, entry *Entry) { c.mu.Lock() defer c.mu.Unlock() - klog.V(5).Infof("Adding SA %s/%s to SA cache: %+v", namespace, name, resp) - c.saCache[namespace+"/"+name] = resp + + key := namespace + "/" + name + klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, entry) + c.saCache[key] = entry + + c.handlerMu.Lock() + defer c.handlerMu.Unlock() + if handler, found := c.notificationHandlers[key]; found { + klog.V(5).Infof("Notifying handlers for %q", key) + close(handler) + delete(c.notificationHandlers, key) + } } -func (c *serviceAccountCache) setCM(name, namespace string, resp *CacheResponse) { +func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) { c.mu.Lock() defer c.mu.Unlock() - klog.V(5).Infof("Adding SA %s/%s to CM cache: %+v", namespace, name, resp) - c.cmCache[namespace+"/"+name] = resp + klog.V(5).Infof("Adding SA %s/%s to CM cache: %+v", namespace, name, entry) + c.cmCache[namespace+"/"+name] = entry } func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenExpiration int64, saInformer coreinformers.ServiceAccountInformer, cmInformer coreinformers.ConfigMapInformer, composeRoleArn ComposeRoleArn) ServiceAccountCache { @@ -233,8 +289,8 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx } c := &serviceAccountCache{ - saCache: map[string]*CacheResponse{}, - cmCache: map[string]*CacheResponse{}, + saCache: map[string]*Entry{}, + cmCache: map[string]*Entry{}, defaultAudience: defaultAudience, annotationPrefix: prefix, 
defaultRegionalSTS: defaultRegionalSTS, @@ -242,6 +298,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx defaultTokenExpiration: defaultTokenExpiration, hasSynced: hasSynced, webhookUsage: webhookUsage, + notificationHandlers: map[string]chan struct{}{}, } saInformer.Informer().AddEventHandler( @@ -298,22 +355,22 @@ func (c *serviceAccountCache) populateCacheFromCM(oldCM, newCM *v1.ConfigMap) er return nil } newConfig := newCM.Data["config"] - sas := make(map[string]*CacheResponse) + sas := make(map[string]*Entry) err := json.Unmarshal([]byte(newConfig), &sas) if err != nil { return fmt.Errorf("failed to unmarshal new config %q: %v", newConfig, err) } - for key, resp := range sas { + for key, entry := range sas { parts := strings.Split(key, "/") - if resp.TokenExpiration == 0 { - resp.TokenExpiration = c.defaultTokenExpiration + if entry.TokenExpiration == 0 { + entry.TokenExpiration = c.defaultTokenExpiration } - c.setCM(parts[1], parts[0], resp) + c.setCM(parts[1], parts[0], entry) } if oldCM != nil { oldConfig := oldCM.Data["config"] - oldCache := make(map[string]*CacheResponse) + oldCache := make(map[string]*Entry) err := json.Unmarshal([]byte(oldConfig), &oldCache) if err != nil { return fmt.Errorf("failed to unmarshal old config %q: %v", oldConfig, err) diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 58c495cac..d4a540cfd 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -3,6 +3,7 @@ package cache import ( "fmt" "strconv" + "sync" "testing" "time" @@ -30,26 +31,126 @@ func TestSaCache(t *testing.T) { } cache := &serviceAccountCache{ - saCache: map[string]*CacheResponse{}, + saCache: map[string]*Entry{}, defaultAudience: "sts.amazonaws.com", annotationPrefix: "eks.amazonaws.com", webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), } - role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default") + resp := cache.Get(Request{Name: "default", Namespace: "default"}) 
- if role != "" || aud != "" { - t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration) + assert.False(t, resp.FoundInCache, "Expected no cache entry to be found") + if resp.RoleARN != "" || resp.Audience != "" { + t.Errorf("Expected role and aud to be empty, got %v", resp) } cache.addSA(testSA) - role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default") + resp = cache.Get(Request{Name: "default", Namespace: "default"}) - assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role) - assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud) - assert.True(t, useRegionalSTS, "Expected regional STS to be true, got false") - assert.Equal(t, int64(3600), tokenExpiration, "Expected token expiration to be 3600, got %d", tokenExpiration) + assert.True(t, resp.FoundInCache, "Expected cache entry to be found") + assert.Equal(t, roleArn, resp.RoleARN, "Expected role to be %s, got %s", roleArn, resp.RoleARN) + assert.Equal(t, "sts.amazonaws.com", resp.Audience, "Expected aud to be sts.amzonaws.com, got %s", resp.Audience) + assert.True(t, resp.UseRegionalSTS, "Expected regional STS to be true, got false") + assert.Equal(t, int64(3600), resp.TokenExpiration, "Expected token expiration to be 3600, got %d", resp.TokenExpiration) +} + +func TestNotification(t *testing.T) { + reqWithNotification := Request{ + Name: "foo", + Namespace: "default", + RequestNotification: true, + } + reqWithoutNotification := Request{ + Name: "foo", + Namespace: "default", + RequestNotification: false, + } + + t.Run("with one notification handler", func(t *testing.T) { + cache := &serviceAccountCache{ + saCache: map[string]*Entry{}, + notificationHandlers: map[string]chan struct{}{}, + webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), + } + + // test that the requested SA is not in the cache + resp := cache.Get(reqWithoutNotification) + assert.False(t, 
resp.FoundInCache, "Expected no cache entry to be found in cache") + + // fetch with notification + resp = cache.Get(reqWithNotification) + + // asynchronously add the SA to the cache + go func() { + time.Sleep(1 * time.Millisecond) + cache.addSA(&v1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + }) + }() + + // wait for the notification + select { + case <-resp.Notifier: + // expected + // test that the requested SA is now in the cache + resp := cache.Get(reqWithoutNotification) + assert.True(t, resp.FoundInCache, "Expected cache entry to be found in cache") + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for notification") + } + }) + + t.Run("with 10 notification handlers", func(t *testing.T) { + cache := &serviceAccountCache{ + saCache: map[string]*Entry{}, + notificationHandlers: map[string]chan struct{}{}, + webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), + } + + // test that the requested SA is not in the cache + resp := cache.Get(reqWithoutNotification) + assert.False(t, resp.FoundInCache, "Expected no cache entry to be found in cache") + + // fetch with notification + resp = cache.Get(reqWithNotification) + + wg := sync.WaitGroup{} + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + // wait for the notification + select { + case <-resp.Notifier: + // expected + // test that the requested SA is now in the cache + resp := cache.Get(reqWithoutNotification) + assert.True(t, resp.FoundInCache, "Expected cache entry to be found in cache") + case <-time.After(1 * time.Second): + t.Error("timeout waiting for notification") + } + }() + } + + // asynchronously add the SA to the cache + go func() { + time.Sleep(1 * time.Millisecond) + cache.addSA(&v1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo", + Namespace: "default", + }, + }) + }() + + wg.Wait() + }) } func TestNonRegionalSTS(t *testing.T) { @@ -157,18 +258,19 @@ func TestNonRegionalSTS(t 
*testing.T) { t.Fatalf("cache never called addSA: %v", err) } - gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default") - if gotRoleArn != roleArn { - t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn) + resp := cache.Get(Request{Name: "default", Namespace: "default"}) + assert.True(t, resp.FoundInCache, "Expected cache entry to be found") + if resp.RoleARN != roleArn { + t.Errorf("got roleArn %v, expected %v", resp.RoleARN, roleArn) } - if gotAudience != audience { - t.Errorf("got audience %v, expected %v", gotAudience, audience) + if resp.Audience != audience { + t.Errorf("got audience %v, expected %v", resp.Audience, audience) } - if strconv.Itoa(int(gotTokenExpiration)) != tokenExpiration { - t.Errorf("got token expiration %v, expected %v", gotTokenExpiration, tokenExpiration) + if strconv.Itoa(int(resp.TokenExpiration)) != tokenExpiration { + t.Errorf("got token expiration %v, expected %v", resp.TokenExpiration, tokenExpiration) } - if useRegionalSTS != tc.expectedUseRegionalSts { - t.Errorf("got use regional STS %v, expected %v", useRegionalSTS, tc.expectedUseRegionalSts) + if resp.UseRegionalSTS != tc.expectedUseRegionalSts { + t.Errorf("got use regional STS %v, expected %v", resp.UseRegionalSTS, tc.expectedUseRegionalSts) } }) } @@ -193,7 +295,7 @@ func TestPopulateCacheFromCM(t *testing.T) { } c := serviceAccountCache{ - cmCache: make(map[string]*CacheResponse), + cmCache: make(map[string]*Entry), } { @@ -202,8 +304,8 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") - if role == "" { + resp := c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN == "" { t.Errorf("cloud not find entry that should have been added") } } @@ -214,8 +316,8 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") - if role == "" { + resp := 
c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN == "" { t.Errorf("cloud not find entry that should have been added") } } @@ -226,8 +328,8 @@ func TestPopulateCacheFromCM(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, _ := c.Get("mysa2", "myns2") - if role != "" { + resp := c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN != "" { t.Errorf("found entry that should have been removed") } } @@ -248,7 +350,7 @@ func TestSAAnnotationRemoval(t *testing.T) { } c := serviceAccountCache{ - saCache: make(map[string]*CacheResponse), + saCache: make(map[string]*Entry), annotationPrefix: "eks.amazonaws.com", webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), } @@ -256,9 +358,9 @@ func TestSAAnnotationRemoval(t *testing.T) { c.addSA(oldSA) { - gotRoleArn, _, _, _ := c.Get("default", "default") - if gotRoleArn != roleArn { - t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn) + resp := c.Get(Request{Name: "default", Namespace: "default"}) + if resp.RoleARN != roleArn { + t.Errorf("got roleArn %q, expected %q", resp.RoleARN, roleArn) } } @@ -268,9 +370,9 @@ func TestSAAnnotationRemoval(t *testing.T) { c.addSA(newSA) { - gotRoleArn, _, _, _ := c.Get("default", "default") - if gotRoleArn != "" { - t.Errorf("got roleArn %v, expected %q", gotRoleArn, "") + resp := c.Get(Request{Name: "default", Namespace: "default"}) + if resp.RoleARN != "" { + t.Errorf("got roleArn %v, expected %q", resp.RoleARN, "") } } } @@ -309,8 +411,8 @@ func TestCachePrecedence(t *testing.T) { sa2.ObjectMeta.Annotations = make(map[string]string) c := serviceAccountCache{ - saCache: make(map[string]*CacheResponse), - cmCache: make(map[string]*CacheResponse), + saCache: make(map[string]*Entry), + cmCache: make(map[string]*Entry), defaultTokenExpiration: pkg.DefaultTokenExpiration, annotationPrefix: "eks.amazonaws.com", webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), @@ -323,13 +425,13 @@ func TestCachePrecedence(t 
*testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, exp := c.Get("mysa2", "myns2") - if role == "" { + resp := c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN == "" { t.Errorf("could not find entry that should have been added") } // We expect that the SA still holds presedence - if exp != int64(saTokenExpiration) { - t.Errorf("expected tokenExpiration %d, got %d", saTokenExpiration, exp) + if resp.TokenExpiration != int64(saTokenExpiration) { + t.Errorf("expected tokenExpiration %d, got %d", saTokenExpiration, resp.TokenExpiration) } } @@ -340,14 +442,14 @@ func TestCachePrecedence(t *testing.T) { } // Removing sa2 from CM, but SA still exists - role, _, _, exp := c.Get("mysa2", "myns2") - if role == "" { + resp := c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN == "" { t.Errorf("could not find entry that should still exist") } // Note that Get returns default expiration if mapping is not found in the cache. - if exp != int64(saTokenExpiration) { - t.Errorf("expected tokenExpiration %d, got %d", saTokenExpiration, exp) + if resp.TokenExpiration != int64(saTokenExpiration) { + t.Errorf("expected tokenExpiration %d, got %d", saTokenExpiration, resp.TokenExpiration) } } @@ -356,8 +458,8 @@ func TestCachePrecedence(t *testing.T) { c.addSA(sa2) // Neither cache should return any hits now - role, _, _, _ := c.Get("myns2", "mysa2") - if role != "" { + resp := c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN != "" { t.Errorf("found entry that should not exist") } @@ -370,13 +472,13 @@ func TestCachePrecedence(t *testing.T) { t.Errorf("failed to build cache: %v", err) } - role, _, _, exp := c.Get("mysa2", "myns2") - if role == "" { + resp := c.Get(Request{Name: "mysa2", Namespace: "myns2"}) + if resp.RoleARN == "" { t.Errorf("cloud not find entry that should have been added") } - if exp != pkg.DefaultTokenExpiration { - t.Errorf("expected tokenExpiration %d, got %d", pkg.DefaultTokenExpiration, 
exp) + if resp.TokenExpiration != pkg.DefaultTokenExpiration { + t.Errorf("expected tokenExpiration %d, got %d", pkg.DefaultTokenExpiration, resp.TokenExpiration) } } @@ -420,19 +522,19 @@ func TestRoleArnComposition(t *testing.T) { cache.Start(stop) defer close(stop) - var roleArn string + var resp Response err := wait.ExponentialBackoff(wait.Backoff{Duration: 10 * time.Millisecond, Factor: 1.0, Steps: 3}, func() (bool, error) { - roleArn, _, _, _ = cache.Get("default", "default") - return roleArn != "", nil + resp = cache.Get(Request{Name: "default", Namespace: "default"}) + return resp.RoleARN != "", nil }) if err != nil { t.Fatalf("cache never returned role arn %v", err) } - arn, err := awsarn.Parse(roleArn) + arn, err := awsarn.Parse(resp.RoleARN) assert.Nil(t, err, "Expected ARN parsing to succeed") - assert.True(t, awsarn.IsARN(roleArn), "Expected ARN validation to be true, got false") + assert.True(t, awsarn.IsARN(resp.RoleARN), "Expected ARN validation to be true, got false") assert.Equal(t, accountID, arn.AccountID, "Expected account ID to be %s, got %s", accountID, arn.AccountID) assert.Equal(t, resource, arn.Resource, "Expected resource to be %s, got %s", resource, arn.Resource) } @@ -506,8 +608,8 @@ func TestGetCommonConfigurations(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { cache := &serviceAccountCache{ - saCache: map[string]*CacheResponse{}, - cmCache: map[string]*CacheResponse{}, + saCache: map[string]*Entry{}, + cmCache: map[string]*Entry{}, defaultAudience: "sts.amazonaws.com", annotationPrefix: "eks.amazonaws.com", webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}), diff --git a/pkg/cache/debug/debug_test.go b/pkg/cache/debug/debug_test.go index 1f5da5d4f..6712192fd 100644 --- a/pkg/cache/debug/debug_test.go +++ b/pkg/cache/debug/debug_test.go @@ -83,7 +83,7 @@ func TestLister(t *testing.T) { t.Errorf("Failed to read response: %v", err) return } - m := map[string]cache.CacheResponse{} + m := 
map[string]cache.Entry{} err = json.Unmarshal(responseBytes, &m) if err != nil { t.Errorf("Failed to unmarshal: %v", err) diff --git a/pkg/cache/fake.go b/pkg/cache/fake.go index 0f5a67869..eeae6629e 100644 --- a/pkg/cache/fake.go +++ b/pkg/cache/fake.go @@ -12,12 +12,12 @@ import ( // FakeServiceAccountCache is a goroutine safe cache for testing type FakeServiceAccountCache struct { mu sync.RWMutex // guards cache - cache map[string]*CacheResponse + cache map[string]*Entry } func NewFakeServiceAccountCache(accounts ...*v1.ServiceAccount) *FakeServiceAccountCache { c := &FakeServiceAccountCache{ - cache: map[string]*CacheResponse{}, + cache: map[string]*Entry{}, } for _, sa := range accounts { arn, _ := sa.Annotations["eks.amazonaws.com/role-arn"] @@ -44,14 +44,20 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{} func (f *FakeServiceAccountCache) Start(chan struct{}) {} // Get gets a service account from the cache -func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) { +func (f *FakeServiceAccountCache) Get(req Request) Response { f.mu.RLock() defer f.mu.RUnlock() - resp, ok := f.cache[namespace+"/"+name] + resp, ok := f.cache[req.CacheKey()] if !ok { - return "", "", false, pkg.DefaultTokenExpiration + return Response{TokenExpiration: pkg.DefaultTokenExpiration} + } + return Response{ + RoleARN: resp.RoleARN, + Audience: resp.Audience, + UseRegionalSTS: resp.UseRegionalSTS, + TokenExpiration: resp.TokenExpiration, + FoundInCache: true, } - return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration } func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) { @@ -68,7 +74,7 @@ func (f *FakeServiceAccountCache) GetCommonConfigurations(name, namespace string func (f *FakeServiceAccountCache) Add(name, namespace, role, aud string, regionalSTS bool, tokenExpiration int64) { f.mu.Lock() defer 
f.mu.Unlock() - f.cache[namespace+"/"+name] = &CacheResponse{ + f.cache[namespace+"/"+name] = &Entry{ RoleARN: role, Audience: aud, UseRegionalSTS: regionalSTS, diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 1425ca14b..31f852654 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -24,6 +24,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "github.com/aws/amazon-eks-pod-identity-webhook/pkg/containercredentials" @@ -77,6 +78,12 @@ func WithAnnotationDomain(domain string) ModifierOpt { return func(m *Modifier) { m.AnnotationDomain = domain } } +// WithSALookupGraceTime sets the grace time to wait for service accounts to appear in cache +func WithSALookupGraceTime(saLookupGraceTime time.Duration) ModifierOpt { + return func(m *Modifier) { m.saLookupGraceTime = saLookupGraceTime } + +} + // NewModifier returns a Modifier with default values func NewModifier(opts ...ModifierOpt) *Modifier { mod := &Modifier{ @@ -101,6 +108,7 @@ type Modifier struct { ContainerCredentialsConfig containercredentials.Config volName string tokenName string + saLookupGraceTime time.Duration } type patchOperation struct { @@ -425,21 +433,38 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig { } // Use the STS WebIdentity method if set - roleArn, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace) - if roleArn != "" { - tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, tokenExpiration) + request := cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName, RequestNotification: true} + response := m.Cache.Get(request) + if !response.FoundInCache && m.saLookupGraceTime > 0 { + klog.Warningf("Service account %s not found in the cache. 
Waiting up to %s to be notified", request.CacheKey(), m.saLookupGraceTime) + select { + case <-response.Notifier: + request = cache.Request{Namespace: pod.Namespace, Name: pod.Spec.ServiceAccountName, RequestNotification: false} + response = m.Cache.Get(request) + if !response.FoundInCache { + klog.Warningf("Service account %s not found in the cache after being notified. Not mutating.", request.CacheKey()) + return nil + } + case <-time.After(m.saLookupGraceTime): + klog.Warningf("Service account %s not found in the cache after %s. Not mutating.", request.CacheKey(), m.saLookupGraceTime) + return nil + } + } + klog.V(5).Infof("Value of roleArn after cache retrieval for service account %s: %s", request.CacheKey(), response.RoleARN) + if response.RoleARN != "" { + tokenExpiration, containersToSkip := m.parsePodAnnotations(pod, response.TokenExpiration) webhookPodCount.WithLabelValues("sts_web_identity").Inc() return &podPatchConfig{ ContainersToSkip: containersToSkip, TokenExpiration: tokenExpiration, - UseRegionalSTS: regionalSTS, - Audience: audience, + UseRegionalSTS: response.UseRegionalSTS, + Audience: response.Audience, MountPath: m.MountPath, VolumeName: m.volName, TokenPath: m.tokenName, - WebIdentityPatchConfig: &webIdentityPatchConfig{RoleArn: roleArn}, + WebIdentityPatchConfig: &webIdentityPatchConfig{RoleArn: response.RoleARN}, ContainerCredentialsPatchConfig: nil, } } diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index cb6aada5f..62c502f1d 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/assert" "io" "io/ioutil" + "k8s.io/apimachinery/pkg/types" "net/http" "net/http/httptest" "reflect" @@ -36,12 +37,15 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) +const uuid = "918ef1dc-928f-4525-99ef-988389f263c3" + func TestMutatePod(t *testing.T) { testServiceAccount := &v1.ServiceAccount{} testServiceAccount.Name = "default" testServiceAccount.Namespace 
= "default" testServiceAccount.Annotations = map[string]string{ - "eks.amazonaws.com/role-arn": "arn:aws:iam::111122223333:role/s3-reader", + "eks.amazonaws.com/role-arn": "arn:aws:iam::111122223333:role/s3-reader", + "eks.amazonaws.com/token-expiration": "3600", } modifier := NewModifier( @@ -63,6 +67,11 @@ func TestMutatePod(t *testing.T) { &v1beta1.AdmissionReview{Request: nil}, &v1beta1.AdmissionResponse{Result: &metav1.Status{Message: "bad content"}}, }, + { + "ValidRequest", + getValidReview(rawPodWithoutVolume), + getValidHandlerResponse(""), + }, } for _, c := range cases { @@ -74,6 +83,22 @@ func TestMutatePod(t *testing.T) { want, _ := json.MarshalIndent(c.response, "", " ") t.Errorf("Unexpected response. Got \n%s\n wanted \n%s", string(got), string(want)) } + var expectedPatchOps, actualPatchOps []byte + if len(response.Patch) > 0 { + patchOps := make([]patchOperation, 0) + if err := json.Unmarshal(response.Patch, &patchOps); err != nil { + t.Errorf("Failed to unmarshal patch: %v", err) + } + actualPatchOps, _ = json.MarshalIndent(patchOps, "", " ") + } + if len(c.response.Patch) > 0 { + patchOps := make([]patchOperation, 0) + if err := json.Unmarshal(c.response.Patch, &patchOps); err != nil { + t.Errorf("Failed to unmarshal patch: %v", err) + } + expectedPatchOps, _ = json.MarshalIndent(patchOps, "", " ") + } + assert.Equal(t, string(expectedPatchOps), string(actualPatchOps)) }) } } @@ -113,17 +138,19 @@ var rawPodWithoutVolume = []byte(` var validPatchIfNoVolumesPresent = 
[]byte(`[{"op":"add","path":"/spec/volumes","value":[{"name":"aws-iam-token","projected":{"sources":[{"serviceAccountToken":{"audience":"sts.amazonaws.com","expirationSeconds":3600,"path":"token"}}]}}]},{"op":"add","path":"/spec/containers","value":[{"name":"balajilovesoreos","image":"amazonlinux","env":[{"name":"AWS_ROLE_ARN","value":"arn:aws:iam::111122223333:role/s3-reader"},{"name":"AWS_WEB_IDENTITY_TOKEN_FILE","value":"/var/run/secrets/eks.amazonaws.com/serviceaccount/token"}],"resources":{},"volumeMounts":[{"name":"aws-iam-token","readOnly":true,"mountPath":"/var/run/secrets/eks.amazonaws.com/serviceaccount"}]}]}]`) -var validHandlerResponse = &v1beta1.AdmissionResponse{ - UID: "918ef1dc-928f-4525-99ef-988389f263c3", - Allowed: true, - Patch: validPatchIfNoVolumesPresent, - PatchType: &jsonPatchType, +func getValidHandlerResponse(uuid string) *v1beta1.AdmissionResponse { + return &v1beta1.AdmissionResponse{ + UID: types.UID(uuid), + Allowed: true, + Patch: validPatchIfNoVolumesPresent, + PatchType: &jsonPatchType, + } } func getValidReview(pod []byte) *v1beta1.AdmissionReview { return &v1beta1.AdmissionReview{ Request: &v1beta1.AdmissionRequest{ - UID: "918ef1dc-928f-4525-99ef-988389f263c3", + UID: uuid, Kind: metav1.GroupVersionKind{ Version: "v1", Kind: "Pod", @@ -216,7 +243,7 @@ func TestModifierHandler(t *testing.T) { "ValidRequestSuccessWithoutVolumes", serializeAdmissionReview(t, getValidReview(rawPodWithoutVolume)), "application/json", - serializeAdmissionReview(t, &v1beta1.AdmissionReview{Response: validHandlerResponse}), + serializeAdmissionReview(t, &v1beta1.AdmissionReview{Response: getValidHandlerResponse(uuid)}), }, }