Skip to content

Commit ad0e80b

Browse files
committed
Add nodeLabeller interface for testing
Signed-off-by: Evan Lezar <[email protected]>
1 parent 0bbed4f commit ad0e80b

File tree

4 files changed

+244
-33
lines changed

4 files changed

+244
-33
lines changed

pkg/mig/reconfigure/api.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ type migParted interface {
3030
applyMIGModeOnly() error
3131
applyMIGConfig() error
3232
}
33+
34+
// nodeLabeller defines an interface for interacting with node labels.
35+
//
36+
//go:generate moq -rm -fmt=goimports -out node-labeller_mock.go . nodeLabeller
37+
type nodeLabeller interface {
38+
getNodeLabelValue(string) (string, error)
39+
setNodeLabelValue(string, string) error
40+
}

pkg/mig/reconfigure/node-labeller_mock.go

Lines changed: 124 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/mig/reconfigure/reconfigure.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
log "github.com/sirupsen/logrus"
1414
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
15+
"k8s.io/client-go/kubernetes"
1516
)
1617

1718
const (
@@ -24,8 +25,9 @@ const (
2425

2526
type reconfigurer struct {
2627
*reconfigureMIGOptions
27-
migParted migParted
2828
commandRunner
29+
migParted migParted
30+
node nodeLabeller
2931
}
3032

3133
// A commandWithOutput runs a command and ensures that STDERR and STDOUT are
@@ -57,6 +59,10 @@ func New(opts ...Option) (Reconfigurer, error) {
5759
DriverLibraryPath: o.DriverLibraryPath,
5860
commandRunner: c,
5961
},
62+
node: &node{
63+
clientset: o.clientset,
64+
name: o.NodeName,
65+
},
6066
}
6167

6268
return r, nil
@@ -73,12 +79,12 @@ func (opts *reconfigurer) Reconfigure() error {
7379
return fmt.Errorf("error validating the selected MIG configuration: %w", err)
7480
}
7581

76-
log.Infof("Getting current value of the '%s' node label", opts.configStateLabel)
77-
state, err := opts.getNodeLabelValue(opts.configStateLabel)
82+
log.Infof("Getting current value of the '%s' node label", opts.ConfigStateLabel)
83+
state, err := opts.node.getNodeLabelValue(opts.ConfigStateLabel)
7884
if err != nil {
79-
return fmt.Errorf("unable to get the value of the %q label: %w", opts.configStateLabel, err)
85+
return fmt.Errorf("unable to get the value of the %q label: %w", opts.ConfigStateLabel, err)
8086
}
81-
log.Infof("Current value of '%s=%s'", opts.configStateLabel, state)
87+
log.Infof("Current value of '%s=%s'", opts.ConfigStateLabel, state)
8288

8389
log.Info("Checking if the selected MIG config is currently applied or not")
8490
if err := opts.migParted.assertMIGConfig(); err == nil {
@@ -124,11 +130,11 @@ func (opts *reconfigurer) Reconfigure() error {
124130
log.Info("If the -r option was passed, the node will be automatically rebooted if this is not successful")
125131
if err := opts.migParted.applyMIGModeOnly(); err != nil || opts.migParted.assertMIGModeOnly() != nil {
126132
if opts.WithReboot {
127-
log.Infof("Changing the '%s' node label to '%s'", opts.configStateLabel, configStateRebooting)
128-
if err := opts.setNodeLabelValue(opts.configStateLabel, configStateRebooting); err != nil {
129-
log.Errorf("Unable to set the value of '%s' to '%s'", opts.configStateLabel, configStateRebooting)
133+
log.Infof("Changing the '%s' node label to '%s'", opts.ConfigStateLabel, configStateRebooting)
134+
if err := opts.node.setNodeLabelValue(opts.ConfigStateLabel, configStateRebooting); err != nil {
135+
log.Errorf("Unable to set the value of '%s' to '%s'", opts.ConfigStateLabel, configStateRebooting)
130136
log.Error("Exiting so as not to reboot multiple times unexpectedly")
131-
return fmt.Errorf("unable to set the value of %q to %q: %w", opts.configStateLabel, configStateRebooting, err)
137+
return fmt.Errorf("unable to set the value of %q to %q: %w", opts.ConfigStateLabel, configStateRebooting, err)
132138
}
133139
return rebootHost(opts.HostRootMount)
134140
}
@@ -376,8 +382,13 @@ func rebootHost(hostRootMount string) error {
376382
return nil
377383
}
378384

379-
func (opts *reconfigureMIGOptions) getNodeLabelValue(label string) (string, error) {
380-
node, err := opts.clientset.CoreV1().Nodes().Get(context.TODO(), opts.NodeName, metav1.GetOptions{})
385+
type node struct {
386+
clientset *kubernetes.Clientset
387+
name string
388+
}
389+
390+
func (n *node) getNodeLabelValue(label string) (string, error) {
391+
node, err := n.clientset.CoreV1().Nodes().Get(context.TODO(), n.name, metav1.GetOptions{})
381392
if err != nil {
382393
return "", fmt.Errorf("unable to get node object: %w", err)
383394
}
@@ -390,16 +401,16 @@ func (opts *reconfigureMIGOptions) getNodeLabelValue(label string) (string, erro
390401
return value, nil
391402
}
392403

393-
func (opts *reconfigureMIGOptions) setNodeLabelValue(label, value string) error {
394-
node, err := opts.clientset.CoreV1().Nodes().Get(context.TODO(), opts.NodeName, metav1.GetOptions{})
404+
func (n *node) setNodeLabelValue(label, value string) error {
405+
node, err := n.clientset.CoreV1().Nodes().Get(context.TODO(), n.name, metav1.GetOptions{})
395406
if err != nil {
396407
return fmt.Errorf("unable to get node object: %w", err)
397408
}
398409

399410
labels := node.GetLabels()
400411
labels[label] = value
401412
node.SetLabels(labels)
402-
_, err = opts.clientset.CoreV1().Nodes().Update(context.TODO(), node, metav1.UpdateOptions{})
413+
_, err = n.clientset.CoreV1().Nodes().Update(context.TODO(), node, metav1.UpdateOptions{})
403414
if err != nil {
404415
return fmt.Errorf("unable to update node object: %w", err)
405416
}

pkg/mig/reconfigure/reconfigure_test.go

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,36 @@ func (c *commandRunnerWithCLI) Run(cmd *exec.Cmd) error {
3535
return c.mock.Run(cmd)
3636
}
3737

38+
type nodeWithLabels struct {
39+
mock *nodeLabellerMock
40+
setLabels map[string]string
41+
}
42+
43+
func (n *nodeWithLabels) getNodeLabelValue(label string) (string, error) {
44+
return n.mock.getNodeLabelValue(label)
45+
}
46+
47+
func (n *nodeWithLabels) setNodeLabelValue(label string, value string) error {
48+
if err := n.mock.setNodeLabelValue(label, value); err != nil {
49+
return err
50+
}
51+
if n.setLabels == nil {
52+
n.setLabels = make(map[string]string)
53+
}
54+
n.setLabels[label] = value
55+
return nil
56+
}
57+
3858
func TestReconfigure(t *testing.T) {
3959
testCases := []struct {
40-
description string
41-
options reconfigureMIGOptions
42-
commandRunner *commandRunnerWithCLI
43-
migParted *migPartedMock
44-
checkMigParted func(*migPartedMock)
45-
expectedError error
46-
expectedCalls [][]string
60+
description string
61+
options reconfigureMIGOptions
62+
migParted *migPartedMock
63+
checkMigParted func(*migPartedMock)
64+
nodeLabeller *nodeWithLabels
65+
checkNodeLabeller func(*nodeWithLabels)
66+
expectedError error
67+
expectedCalls [][]string
4768
}{
4869
{
4970
description: "mig assert valid config failure does not call commands",
@@ -53,13 +74,7 @@ func TestReconfigure(t *testing.T) {
5374
SelectedMIGConfig: "selected-mig-config",
5475
DriverLibraryPath: "/path/to/libnvidia-ml.so.1",
5576
HostRootMount: "/host/",
56-
},
57-
commandRunner: &commandRunnerWithCLI{
58-
mock: &commandRunnerMock{
59-
RunFunc: func(cmd *exec.Cmd) error {
60-
return fmt.Errorf("error running command %v", cmd.Path)
61-
},
62-
},
77+
ConfigStateLabel: "example.com/config.state",
6378
},
6479
migParted: &migPartedMock{
6580
assertValidMIGConfigFunc: func() error {
@@ -76,9 +91,53 @@ func TestReconfigure(t *testing.T) {
7691
expectedError: fmt.Errorf("error validating the selected MIG configuration: invalid mig config"),
7792
expectedCalls: nil,
7893
},
94+
{
95+
description: "node label error is causes exit",
96+
options: reconfigureMIGOptions{
97+
NodeName: "NodeName",
98+
MIGPartedConfigFile: "/path/to/config/file.yaml",
99+
SelectedMIGConfig: "selected-mig-config",
100+
DriverLibraryPath: "/path/to/libnvidia-ml.so.1",
101+
HostRootMount: "/host/",
102+
ConfigStateLabel: "example.com/config.state",
103+
},
104+
migParted: &migPartedMock{
105+
assertValidMIGConfigFunc: func() error {
106+
return nil
107+
},
108+
},
109+
checkMigParted: func(mpm *migPartedMock) {
110+
require.Len(t, mpm.calls.assertValidMIGConfig, 1)
111+
require.Len(t, mpm.calls.applyMIGConfig, 0)
112+
require.Len(t, mpm.calls.assertMIGModeOnly, 0)
113+
require.Len(t, mpm.calls.applyMIGModeOnly, 0)
114+
require.Len(t, mpm.calls.applyMIGConfig, 0)
115+
},
116+
nodeLabeller: &nodeWithLabels{
117+
mock: &nodeLabellerMock{
118+
getNodeLabelValueFunc: func(s string) (string, error) {
119+
return "", fmt.Errorf("error getting label")
120+
},
121+
},
122+
},
123+
checkNodeLabeller: func(nwl *nodeWithLabels) {
124+
calls := nwl.mock.getNodeLabelValueCalls()
125+
require.Len(t, calls, 1)
126+
require.EqualValues(t, []struct{ S string }{{"example.com/config.state"}}, calls)
127+
},
128+
expectedError: fmt.Errorf(`unable to get the value of the "example.com/config.state" label: error getting label`),
129+
},
79130
}
80131

81132
for _, tc := range testCases {
133+
commandRunner := &commandRunnerWithCLI{
134+
mock: &commandRunnerMock{
135+
RunFunc: func(cmd *exec.Cmd) error {
136+
return fmt.Errorf("error running command %v", cmd.Path)
137+
},
138+
},
139+
}
140+
82141
t.Run(tc.description, func(t *testing.T) {
83142
// TODO: Once we have better mocks in place for the following
84143
// functionality, we can update this.
@@ -91,17 +150,26 @@ func TestReconfigure(t *testing.T) {
91150

92151
r := &reconfigurer{
93152
reconfigureMIGOptions: &tc.options,
94-
commandRunner: tc.commandRunner,
153+
commandRunner: commandRunner,
95154
migParted: tc.migParted,
155+
node: tc.nodeLabeller,
96156
}
97157

98158
err := r.Reconfigure()
99-
require.EqualValues(t, tc.expectedError.Error(), err.Error())
100-
101-
tc.checkMigParted(tc.migParted)
159+
if tc.expectedError == nil {
160+
require.NoError(t, err)
161+
} else {
162+
require.EqualError(t, err, tc.expectedError.Error())
163+
}
102164

103-
require.EqualValues(t, tc.expectedCalls, tc.commandRunner.calls)
165+
if tc.checkMigParted != nil {
166+
tc.checkMigParted(tc.migParted)
167+
}
168+
if tc.checkNodeLabeller != nil {
169+
tc.checkNodeLabeller(tc.nodeLabeller)
170+
}
104171

172+
require.EqualValues(t, tc.expectedCalls, commandRunner.calls)
105173
})
106174
}
107175
}

0 commit comments

Comments
 (0)