From 91d1678b42ad587bcb0cb9af51530b754f06fca6 Mon Sep 17 00:00:00 2001 From: Eddie Torres Date: Wed, 19 Nov 2025 18:50:57 +0000 Subject: [PATCH] Fix metadata labeler test flakes Signed-off-by: Eddie Torres --- pkg/cloud/metadata/labels_test.go | 738 ++++++++++++++---------------- 1 file changed, 347 insertions(+), 391 deletions(-) diff --git a/pkg/cloud/metadata/labels_test.go b/pkg/cloud/metadata/labels_test.go index c75d8eba57..28855a969e 100644 --- a/pkg/cloud/metadata/labels_test.go +++ b/pkg/cloud/metadata/labels_test.go @@ -17,483 +17,439 @@ package metadata import ( "context" "errors" - "fmt" - reflect "reflect" - "strconv" - "sync" + "slices" + "strings" "testing" - "time" "github.com/aws/aws-sdk-go-v2/service/ec2/types" - gomock "github.com/golang/mock/gomock" + "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/util" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes/fake" - clienttesting "k8s.io/client-go/testing" "k8s.io/client-go/tools/cache" ) func init() { - // Ensure variables are initialized - // TODO: Figure out a cleaner way to do this in tests initVariables() } -func TestPatchNewNodes(t *testing.T) { - testCases := []struct { - name string - newNode *corev1.Node - newPV *corev1.PersistentVolume - newNodeMetadata map[string]enisVolumes - expErr error +func TestGetMetadata(t *testing.T) { + tests := []struct { + name string + nodes []corev1.Node + pvs []corev1.PersistentVolume + instances []*types.Instance + cloudErr error + want map[string]enisVolumes + wantErr bool }{ { - name: "success: normal, new node added", - newNode: &corev1.Node{ - ObjectMeta: metav1.ObjectMeta{ - Name: "i-001", - Labels: make(map[string]string), - }, - Spec: corev1.NodeSpec{ - ProviderID: "example/i-001", - }, + name: "single node with volumes and ENIs", + nodes: []corev1.Node{ + makeNode("i-001", "aws:///us-west-2a/i-001"), + }, + instances: []*types.Instance{ + makeInstance("i-001", 2, []string{"vol-001", "vol-002"}), + }, + want: map[string]enisVolumes{ + "i-001": {ENIs: 2, Volumes: 1}, + }, + }, + { + name: "multiple nodes", + nodes: []corev1.Node{ + makeNode("i-001", "aws:///us-west-2a/i-001"), + makeNode("i-002", "aws:///us-west-2b/i-002"), + }, + instances: []*types.Instance{ + makeInstance("i-001", 1, []string{"vol-001"}), + makeInstance("i-002", 3, []string{"vol-002", "vol-003", "vol-004"}), + }, + want: map[string]enisVolumes{ + "i-001": {ENIs: 1, Volumes: 0}, + "i-002": {ENIs: 3, Volumes: 2}, + }, + }, + { + name: "exclude CSI managed volumes", + nodes: []corev1.Node{ + makeNode("i-001", "aws:///us-west-2a/i-001"), + }, + pvs: []corev1.PersistentVolume{ + makeCSIPV("pv-001", "vol-001"), + }, + instances: []*types.Instance{ + makeInstance("i-001", 1, []string{"vol-001", "vol-002"}), + }, + want: map[string]enisVolumes{ + "i-001": {ENIs: 1, Volumes: 0}, + }, + }, + { + name: "exclude migrated volumes", + nodes: []corev1.Node{ + makeNode("i-001", "aws:///us-west-2a/i-001"), + }, + pvs: []corev1.PersistentVolume{ + makeMigratedPV("pv-001", "vol-001"), + }, + instances: []*types.Instance{ + makeInstance("i-001", 1, []string{"vol-001", "vol-002"}), + }, + want: map[string]enisVolumes{ + "i-001": {ENIs: 1, Volumes: 0}, }, - newNodeMetadata: map[string]enisVolumes{ - "i-001": {ENIs: 2, Volumes: 2}, + }, + { + name: "cloud error", + nodes: []corev1.Node{ + makeNode("i-001", "aws:///us-west-2a/i-001"), }, - expErr: nil, + cloudErr: errors.New("EC2 API error"), + wantErr: true, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - mockCtrl := gomock.NewController(t) - mockCloud := cloud.NewMockCloud(mockCtrl) - - mockCloud.EXPECT().GetInstancesPatching(gomock.Any(), gomock.Any()).Return( - []*types.Instance{newFakeInstance(tc.newNode.Name, tc.newNodeMetadata[tc.newNode.Name].ENIs, tc.newNodeMetadata[tc.newNode.Name].Volumes+1)}, - tc.expErr, - ) - - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - watcherStarted := make(chan struct{}) - mockClientSet := newMockClientSet(watcherStarted) - factory := informers.NewSharedInformerFactory(mockClientSet, 0) - pvInformer := factory.Core().V1().PersistentVolumes().Informer() - err := pvInformer.AddIndexers(cache.Indexers{ - "volumeID": volumeIDIndexFunc, - }) - if err != nil { - t.Fatalf("Failed to add volume ID indexer: %v", err) - } - nodesInformer := factory.Core().V1().Nodes().Informer() - patchError := patchNewNodes(ctx, mockClientSet, mockCloud, nodesInformer, pvInformer) - if patchError != nil { - if tc.expErr == nil { - t.Fatalf("MetadataInformer() failed: expected no error, got: %v", patchError) - } - if patchError.Error() != tc.expErr.Error() { - t.Fatalf("MetadataInformer() failed: expected error %q, got %q", tc.expErr, patchError) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCloud := cloud.NewMockCloud(ctrl) + nodeList := &corev1.NodeList{Items: tt.nodes} + expectedNodeIDs := make([]string, 0, len(tt.nodes)) + for _, node := range tt.nodes { + if id, err := parseProviderID(&node); err == nil && strings.HasPrefix(id, "i-") { + expectedNodeIDs = append(expectedNodeIDs, id) } } - - factory.Start(ctx.Done()) - cache.WaitForCacheSync(ctx.Done()) - <-watcherStarted - - _, err = mockClientSet.CoreV1().Nodes().Create(t.Context(), tc.newNode, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("error injecting node add: %v", err) - } - - // Mock k8s client is racy - var node *corev1.Node - start := time.Now() - timeout := 5 * time.Second - for time.Since(start) < timeout { - node, err = mockClientSet.CoreV1().Nodes().Get(t.Context(), tc.newNode.Name, metav1.GetOptions{}) - if err == nil && node.GetLabels()[ENIsLabel] != "" { - break - } - time.Sleep(100 * time.Millisecond) + if len(expectedNodeIDs) > 0 || tt.cloudErr != nil { + mockCloud.EXPECT().GetInstancesPatching(ctx, expectedNodeIDs). + Return(tt.instances, tt.cloudErr).Times(1) } - expectedENIs := strconv.Itoa(tc.newNodeMetadata[node.Name].ENIs) - expectedVol := strconv.Itoa(tc.newNodeMetadata[node.Name].Volumes) + pvInformer := setupPVInformer(t, tt.pvs) - labeledENIs := node.GetLabels()[ENIsLabel] - labeledVol := node.GetLabels()[VolumesLabel] + got, err := getMetadata(ctx, mockCloud, nodeList, pvInformer) - if labeledENIs != expectedENIs { - t.Fatalf("MetadataInformer() failed: expected %s ENIs, got %s", expectedENIs, labeledENIs) + if (err != nil) != tt.wantErr { + t.Errorf("getMetadata() error = %v, wantErr %v", err, tt.wantErr) + return } - if labeledVol != expectedVol { - t.Fatalf("MetadataInformer() failed: expected %s volumes, got %s", expectedVol, labeledVol) + if !tt.wantErr && !equalEnisVolumesMap(got, tt.want) { + t.Errorf("getMetadata() = %v, want %v", got, tt.want) } }) } } -func newMockClientSet(watcherStarted chan struct{}) *fake.Clientset { - mockClientSet := fake.NewSimpleClientset() - var once sync.Once - mockClientSet.PrependWatchReactor("*", func(action clienttesting.Action) (handled bool, ret watch.Interface, err error) { - gvr := action.GetResource() - ns := action.GetNamespace() - watch, err := mockClientSet.Tracker().Watch(gvr, ns) - if err != nil { - return false, nil, err - } - once.Do(func() { - close(watcherStarted) - }) - return true, watch, nil - }) - return mockClientSet -} - -func newFakeInstance(instanceID string, numENIs, numVolumes int) *types.Instance { - blockDevices := make([]types.InstanceBlockDeviceMapping, numVolumes) - for i := range numVolumes { - volumeID := fmt.Sprintf("vol-00%d", i+1) - blockDevices[i] = types.InstanceBlockDeviceMapping{ - Ebs: &types.EbsInstanceBlockDevice{ - VolumeId: &volumeID, +func TestPatchSingleNode(t *testing.T) { + tests := []struct { + name string + node corev1.Node + metadata map[string]enisVolumes + wantENIs string + wantVolumes string + wantErr bool + }{ + { + name: "patch node successfully", + node: makeNode("i-001", "aws:///us-west-2a/i-001"), + metadata: map[string]enisVolumes{ + "i-001": {ENIs: 3, Volumes: 5}, }, - } - } - - return &types.Instance{ - InstanceId: &instanceID, - BlockDeviceMappings: blockDevices, - NetworkInterfaces: make([]types.InstanceNetworkInterface, numENIs), + wantENIs: "3", + wantVolumes: "5", + }, + { + name: "invalid provider ID", + node: makeNode("i-001", "invalid"), + metadata: map[string]enisVolumes{ + "i-001": {ENIs: 1, Volumes: 1}, + }, + wantErr: true, + }, } -} -func mockAddPV(newPV *corev1.PersistentVolume, instances []*types.Instance) []*types.Instance { - if newPV == nil { - return instances - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + clientset := fake.NewSimpleClientset(&tt.node) - var volumeID string + err := patchSingleNode(ctx, tt.node, tt.metadata, clientset) - if newPV.Spec.CSI != nil && newPV.Spec.CSI.Driver == util.GetDriverName() { - volumeID = newPV.Spec.CSI.VolumeHandle - } else if newPV.Spec.AWSElasticBlockStore != nil { - volumeID = newPV.Spec.AWSElasticBlockStore.VolumeID - } + if (err != nil) != tt.wantErr { + t.Errorf("patchSingleNode() error = %v, wantErr %v", err, tt.wantErr) + return + } - instances[0].BlockDeviceMappings = append(instances[0].BlockDeviceMappings, - types.InstanceBlockDeviceMapping{ - Ebs: &types.EbsInstanceBlockDevice{ - VolumeId: &volumeID, - }, + if !tt.wantErr { + node, _ := clientset.CoreV1().Nodes().Get(ctx, tt.node.Name, metav1.GetOptions{}) + if got := node.Labels[ENIsLabel]; got != tt.wantENIs { + t.Errorf("ENIs label = %v, want %v", got, tt.wantENIs) + } + if got := node.Labels[VolumesLabel]; got != tt.wantVolumes { + t.Errorf("Volumes label = %v, want %v", got, tt.wantVolumes) + } + } }) - - return instances + } } -func TestGetMetadata(t *testing.T) { - defaultNode := &corev1.NodeList{Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-001", - }, - - Spec: corev1.NodeSpec{ - ProviderID: "example/i-001", - }}, - }} - - testCases := []struct { - name string - instances []*types.Instance - nodes *corev1.NodeList - expectedMetadata map[string]enisVolumes - newPV *corev1.PersistentVolume - expErr error +func TestPatchNodes(t *testing.T) { + tests := []struct { + name string + nodes []corev1.Node + metadata map[string]enisVolumes + patchFails int + wantErr bool }{ { - name: "success: normal with multiple instances", - instances: []*types.Instance{newFakeInstance("i-001", 1, 1), newFakeInstance("i-002", 2, 3)}, - nodes: &corev1.NodeList{Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-001", - }, - - Spec: corev1.NodeSpec{ - ProviderID: "example/i-001", - }}, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-002", - }, - - Spec: corev1.NodeSpec{ - ProviderID: "example/i-002", - }}, - }}, - expectedMetadata: map[string]enisVolumes{ - "i-001": {ENIs: 1, Volumes: 0}, - "i-002": {ENIs: 2, Volumes: 2}, + name: "patch all nodes successfully", + nodes: []corev1.Node{ + makeNode("i-001", "aws:///us-west-2a/i-001"), + makeNode("i-002", "aws:///us-west-2b/i-002"), + }, + metadata: map[string]enisVolumes{ + "i-001": {ENIs: 1, Volumes: 2}, + "i-002": {ENIs: 3, Volumes: 4}, }, - newPV: nil, - expErr: nil, + patchFails: 5, }, { - name: "success: normal with one instance", - instances: []*types.Instance{newFakeInstance("i-001", 5, 2)}, - nodes: defaultNode, - expectedMetadata: map[string]enisVolumes{ - "i-001": {ENIs: 5, Volumes: 1}, + name: "fail when too many errors", + nodes: []corev1.Node{ + makeNode("i-001", "invalid"), + makeNode("i-002", "invalid"), + makeNode("i-003", "invalid"), + makeNode("i-004", "invalid"), + makeNode("i-005", "invalid"), }, - newPV: nil, - expErr: nil, + metadata: map[string]enisVolumes{}, + patchFails: 5, + wantErr: true, }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + nodeList := &corev1.NodeList{Items: tt.nodes} + clientset := fake.NewSimpleClientset(nodeList) + + err := patchNodes(ctx, nodeList, tt.metadata, clientset, tt.patchFails) + + if (err != nil) != tt.wantErr { + t.Errorf("patchNodes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestVolumeIDIndexFunc(t *testing.T) { + tests := []struct { + name string + pv any + want []string + }{ { - name: "success: normal with one instance and add one non csi managed PV", - instances: []*types.Instance{newFakeInstance("i-001", 5, 2)}, - nodes: defaultNode, - expectedMetadata: map[string]enisVolumes{ - "i-001": {ENIs: 5, Volumes: 2}, - }, - newPV: &corev1.PersistentVolume{ - Spec: corev1.PersistentVolumeSpec{ - PersistentVolumeSource: corev1.PersistentVolumeSource{ - CSI: &corev1.CSIPersistentVolumeSource{ - Driver: "", - VolumeHandle: "vol-003", - }, - }, - }, - }, - expErr: nil, + name: "CSI volume", + pv: makeCSIPVPtr("pv-001", "vol-001"), + want: []string{"vol-001"}, }, { - name: "success: normal with one instance and add one csi managed PV", - instances: []*types.Instance{newFakeInstance("i-001", 5, 2)}, - nodes: defaultNode, - expectedMetadata: map[string]enisVolumes{ - "i-001": {ENIs: 5, Volumes: 1}, - }, - newPV: &corev1.PersistentVolume{ - Spec: corev1.PersistentVolumeSpec{ - PersistentVolumeSource: corev1.PersistentVolumeSource{ - CSI: &corev1.CSIPersistentVolumeSource{ - Driver: util.GetDriverName(), - VolumeHandle: "vol-003", - }, - }, - }, - }, - expErr: nil, + name: "migrated volume", + pv: makeMigratedPVPtr("pv-001", "vol-001"), + want: []string{"vol-001"}, }, { - name: "success: normal with one instance and add one migrated PV", - instances: []*types.Instance{newFakeInstance("i-001", 5, 2)}, - nodes: defaultNode, - expectedMetadata: map[string]enisVolumes{ - "i-001": {ENIs: 5, Volumes: 1}, - }, - newPV: &corev1.PersistentVolume{ - Spec: corev1.PersistentVolumeSpec{ - PersistentVolumeSource: corev1.PersistentVolumeSource{ - CSI: &corev1.CSIPersistentVolumeSource{ - Driver: "", - }, - AWSElasticBlockStore: &corev1.AWSElasticBlockStoreVolumeSource{ - VolumeID: "vol-003", - }, - }, - }, + name: "non-EBS volume", + pv: &corev1.PersistentVolume{ + Spec: corev1.PersistentVolumeSpec{}, }, - expErr: nil, + want: []string{}, }, { - name: "error: describe instances error", - instances: []*types.Instance{newFakeInstance("i-001", 5, 2)}, - nodes: defaultNode, - expectedMetadata: map[string]enisVolumes{}, - newPV: nil, - expErr: errors.New("failed to describe instances"), + name: "invalid object", + pv: "not a PV", + want: []string{}, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - defer cancel() - - watcherStarted := make(chan struct{}) - mockClientSet := newMockClientSet(watcherStarted) - factory := informers.NewSharedInformerFactory(mockClientSet, 0) - pvInformer := factory.Core().V1().PersistentVolumes().Informer() - err := pvInformer.AddIndexers(cache.Indexers{ - "volumeID": volumeIDIndexFunc, - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := volumeIDIndexFunc(tt.pv) if err != nil { - t.Fatalf("Failed to add volume ID indexer: %v", err) - } - factory.Start(ctx.Done()) - cache.WaitForCacheSync(ctx.Done()) - <-watcherStarted - - if tc.newPV != nil { - _, err := mockClientSet.CoreV1().PersistentVolumes().Create(t.Context(), tc.newPV, metav1.CreateOptions{}) - if err != nil { - t.Fatalf("error injecting PV add: %v", err) - } - time.Sleep(500 * time.Millisecond) + t.Errorf("volumeIDIndexFunc() error = %v", err) + return } - - tc.instances = mockAddPV(tc.newPV, tc.instances) - mockCtrl := gomock.NewController(t) - mockCloud := cloud.NewMockCloud(mockCtrl) - - mockCloud.EXPECT().GetInstancesPatching(gomock.Any(), gomock.Any()).Return( - tc.instances, - tc.expErr, - ) - - ENIsVolumesMap, err := getMetadata(t.Context(), mockCloud, tc.nodes, pvInformer) - if err != nil { - if tc.expErr == nil { - t.Fatalf("GetMetadata() failed: expected no error, got: %v", err) - } - if err.Error() != tc.expErr.Error() { - t.Fatalf("GetMetadata() failed: expected error %q, got %q", tc.expErr, err) - } - } else { - if tc.expErr != nil { - t.Fatal("GetMetadata() failed: expected error, got nothing") - } - if !reflect.DeepEqual(ENIsVolumesMap, tc.expectedMetadata) { - t.Fatalf("GetMetadata() failed: expected %v, go: %v", tc.expectedMetadata, ENIsVolumesMap) - } + if !slices.Equal(got, tt.want) { + t.Errorf("volumeIDIndexFunc() = %v, want %v", got, tt.want) } - mockCtrl.Finish() }) } } -func TestPatchLabels(t *testing.T) { - testCases := []struct { - name string - nodes corev1.NodeList - ENIsVolumesMap map[string]enisVolumes - expErr error +func TestGetNonCSIManagedVolumes(t *testing.T) { + tests := []struct { + name string + pvs []corev1.PersistentVolume + volumes []types.InstanceBlockDeviceMapping + want int }{ { - name: "success: normal patching 1 node", - ENIsVolumesMap: map[string]enisVolumes{ - "i-001": {ENIs: 1, Volumes: 1}, + name: "all non-CSI volumes", + volumes: []types.InstanceBlockDeviceMapping{ + makeBlockDevice("vol-001"), + makeBlockDevice("vol-002"), }, - nodes: corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-001", - Labels: map[string]string{}, - }, - Spec: corev1.NodeSpec{ - ProviderID: "example/i-001", - }, - }, - }, - }, - expErr: nil, + want: 2, }, { - name: "success: normal patching 2 nodes", - ENIsVolumesMap: map[string]enisVolumes{ - "i-001": {ENIs: 1, Volumes: 1}, - "i-002": {ENIs: 2, Volumes: 3}, + name: "one CSI managed volume", + pvs: []corev1.PersistentVolume{ + makeCSIPV("pv-001", "vol-001"), }, - nodes: corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-001", - Labels: map[string]string{}, - }, - Spec: corev1.NodeSpec{ - ProviderID: "example/i-001", - }, - }, - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-002", - Labels: map[string]string{}, - }, - Spec: corev1.NodeSpec{ - ProviderID: "example/i-002", - }, - }, - }, + volumes: []types.InstanceBlockDeviceMapping{ + makeBlockDevice("vol-001"), + makeBlockDevice("vol-002"), }, - expErr: nil, + want: 1, }, { - name: "error: failed to patch 1 node", - ENIsVolumesMap: map[string]enisVolumes{ - "i-001": {ENIs: 1, Volumes: 1}, + name: "all CSI managed volumes", + pvs: []corev1.PersistentVolume{ + makeCSIPV("pv-001", "vol-001"), + makeCSIPV("pv-002", "vol-002"), + }, + volumes: []types.InstanceBlockDeviceMapping{ + makeBlockDevice("vol-001"), + makeBlockDevice("vol-002"), + }, + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pvInformer := setupPVInformer(t, tt.pvs) + got := getNonCSIManagedVolumes(pvInformer, tt.volumes) + if got != tt.want { + t.Errorf("getNonCSIManagedVolumes() = %v, want %v", got, tt.want) + } + }) + } +} + +func makeNode(name, providerID string) corev1.Node { + return corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: make(map[string]string), + }, + Spec: corev1.NodeSpec{ + ProviderID: providerID, + }, + } +} + +func makeInstance(id string, numENIs int, volumeIDs []string) *types.Instance { + blockDevices := make([]types.InstanceBlockDeviceMapping, len(volumeIDs)) + for i, volID := range volumeIDs { + volIDCopy := volID + blockDevices[i] = types.InstanceBlockDeviceMapping{ + Ebs: &types.EbsInstanceBlockDevice{ + VolumeId: &volIDCopy, }, - nodes: corev1.NodeList{ - Items: []corev1.Node{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "i-001", - Labels: map[string]string{}, - }, - Spec: corev1.NodeSpec{ - ProviderID: "", - }, - }, + } + } + + return &types.Instance{ + InstanceId: &id, + NetworkInterfaces: make([]types.InstanceNetworkInterface, numENIs), + BlockDeviceMappings: blockDevices, + } +} + +func makeCSIPV(name, volumeHandle string) corev1.PersistentVolume { + return corev1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{Name: name}, + Spec: corev1.PersistentVolumeSpec{ + PersistentVolumeSource: corev1.PersistentVolumeSource{ + CSI: &corev1.CSIPersistentVolumeSource{ + Driver: util.GetDriverName(), + VolumeHandle: volumeHandle, }, }, - expErr: errors.New("failed to patch 1 nodes"), }, } +} - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - clientset := fake.NewSimpleClientset(&tc.nodes) - err := patchNodes(t.Context(), &tc.nodes, tc.ENIsVolumesMap, clientset, 1) - if err != nil { - if tc.expErr == nil { - t.Fatalf("PatchNodes() failed: expected no error, got: %v", err) - } - if err.Error() != tc.expErr.Error() { - t.Fatalf("PatchNodes() failed: expected error %q, got %q", tc.expErr, err) - } - } else { - if tc.expErr != nil { - t.Fatal("PatchNodes() failed: expected error, got nothing") - } +func makeCSIPVPtr(name, volumeHandle string) *corev1.PersistentVolume { + pv := makeCSIPV(name, volumeHandle) + return &pv +} + +func makeMigratedPV(name, volumeID string) corev1.PersistentVolume { + return corev1.PersistentVolume{ + ObjectMeta: metav1.ObjectMeta{Name: name}, + Spec: corev1.PersistentVolumeSpec{ + PersistentVolumeSource: corev1.PersistentVolumeSource{ + AWSElasticBlockStore: &corev1.AWSElasticBlockStoreVolumeSource{ + VolumeID: volumeID, + }, + }, + }, + } +} - for _, originalNode := range tc.nodes.Items { - node, _ := clientset.CoreV1().Nodes().Get(t.Context(), originalNode.Name, metav1.GetOptions{}) - expectedENIs := strconv.Itoa(tc.ENIsVolumesMap[originalNode.Name].ENIs) - gotENIs := node.GetLabels()[ENIsLabel] +func makeMigratedPVPtr(name, volumeID string) *corev1.PersistentVolume { + pv := makeMigratedPV(name, volumeID) + return &pv +} - expectedVolumes := strconv.Itoa(tc.ENIsVolumesMap[originalNode.Name].Volumes) - gotVolumes := node.GetLabels()[VolumesLabel] +func makeBlockDevice(volumeID string) types.InstanceBlockDeviceMapping { + volIDCopy := volumeID + return types.InstanceBlockDeviceMapping{ + Ebs: &types.EbsInstanceBlockDevice{ + VolumeId: &volIDCopy, + }, + } +} - if node.GetLabels()[ENIsLabel] != strconv.Itoa(tc.ENIsVolumesMap[originalNode.Name].ENIs) { - t.Fatalf("PatchNodes() failed: expected %q ENIs, got %q", expectedENIs, gotENIs) - } - if node.GetLabels()[VolumesLabel] != strconv.Itoa(tc.ENIsVolumesMap[originalNode.Name].Volumes) { - t.Fatalf("PatchNodes() failed: expected %q volumes, got %q", expectedVolumes, gotVolumes) - } - } - } - }) +func setupPVInformer(t *testing.T, pvs []corev1.PersistentVolume) cache.SharedIndexInformer { + t.Helper() + clientset := fake.NewSimpleClientset() + for i := range pvs { + _, err := clientset.CoreV1().PersistentVolumes().Create(context.Background(), &pvs[i], metav1.CreateOptions{}) + if err != nil { + t.Fatalf("Failed to create PV: %v", err) + } + } + factory := informers.NewSharedInformerFactory(clientset, 0) + pvInformer := factory.Core().V1().PersistentVolumes().Informer() + if err := pvInformer.AddIndexers(cache.Indexers{"volumeID": volumeIDIndexFunc}); err != nil { + t.Fatalf("Failed to add indexer: %v", err) + } + stopCh := make(chan struct{}) + t.Cleanup(func() { close(stopCh) }) + factory.Start(stopCh) + cache.WaitForCacheSync(stopCh, pvInformer.HasSynced) + return pvInformer +} + +func equalEnisVolumesMap(a, b map[string]enisVolumes) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if b[k] != v { + return false + } } + return true }