Skip to content

Commit 5aef9c1

Browse files
committed
Deduplicate requested device IDs
This change ensures that the incoming device IDs are deduplicated before updating the AllocateResponse. This avoids cases where the NVIDIA_VISIBLE_DEVICES envvar or CDI annotations contain repeated device UUIDs or INDICES that do not add additional modifications to the container. Signed-off-by: Evan Lezar <[email protected]>
1 parent b4317da commit 5aef9c1

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

internal/plugin/server.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ func (plugin *nvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.
319319
}
320320

321321
func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*pluginapi.ContainerAllocateResponse, error) {
322-
deviceIDs := plugin.deviceIDsFromAnnotatedDeviceIDs(requestIds)
322+
deviceIDs := plugin.uniqueDeviceIDsFromAnnotatedDeviceIDs(requestIds)
323323

324324
// Create an empty response that will be updated as required below.
325325
response := &pluginapi.ContainerAllocateResponse{
@@ -452,15 +452,24 @@ func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Durat
452452
return c, nil
453453
}
454454

455-
func (plugin *nvidiaDevicePlugin) deviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
455+
func (plugin *nvidiaDevicePlugin) uniqueDeviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
456456
var deviceIDs []string
457457
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyUUID {
458458
deviceIDs = rm.AnnotatedIDs(ids).GetIDs()
459459
}
460460
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyIndex {
461461
deviceIDs = plugin.rm.Devices().Subset(ids).GetIndices()
462462
}
463-
return deviceIDs
463+
var uniqueIDs []string
464+
seen := make(map[string]bool)
465+
for _, id := range deviceIDs {
466+
if seen[id] {
467+
continue
468+
}
469+
seen[id] = true
470+
uniqueIDs = append(uniqueIDs, id)
471+
}
472+
return uniqueIDs
464473
}
465474

466475
func (plugin *nvidiaDevicePlugin) apiDevices() []*pluginapi.Device {

internal/plugin/server_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package plugin
1818

1919
import (
20+
"context"
2021
"testing"
2122

2223
"github.com/stretchr/testify/require"
@@ -25,8 +26,88 @@ import (
2526
v1 "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
2627
"github.com/NVIDIA/k8s-device-plugin/internal/cdi"
2728
"github.com/NVIDIA/k8s-device-plugin/internal/imex"
29+
"github.com/NVIDIA/k8s-device-plugin/internal/rm"
2830
)
2931

32+
func TestAllocate(t *testing.T) {
33+
testCases := []struct {
34+
description string
35+
request *pluginapi.AllocateRequest
36+
expectedError error
37+
expectedResponse *pluginapi.AllocateResponse
38+
}{
39+
{
40+
description: "single device",
41+
request: &pluginapi.AllocateRequest{
42+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
43+
{
44+
DevicesIDs: []string{"foo"},
45+
},
46+
},
47+
},
48+
expectedResponse: &pluginapi.AllocateResponse{
49+
ContainerResponses: []*pluginapi.ContainerAllocateResponse{
50+
{
51+
Envs: map[string]string{
52+
"NVIDIA_VISIBLE_DEVICES": "foo",
53+
},
54+
},
55+
},
56+
},
57+
},
58+
{
59+
description: "duplicate device IDs",
60+
request: &pluginapi.AllocateRequest{
61+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
62+
{
63+
DevicesIDs: []string{"foo", "bar", "foo"},
64+
},
65+
},
66+
},
67+
expectedResponse: &pluginapi.AllocateResponse{
68+
ContainerResponses: []*pluginapi.ContainerAllocateResponse{
69+
{
70+
Envs: map[string]string{
71+
"NVIDIA_VISIBLE_DEVICES": "foo,bar",
72+
},
73+
},
74+
},
75+
},
76+
},
77+
}
78+
79+
for _, tc := range testCases {
80+
t.Run(tc.description, func(t *testing.T) {
81+
plugin := nvidiaDevicePlugin{
82+
rm: &rm.ResourceManagerMock{
83+
ValidateRequestFunc: func(annotatedIDs rm.AnnotatedIDs) error {
84+
return nil
85+
},
86+
},
87+
config: &v1.Config{
88+
Flags: v1.Flags{
89+
CommandLineFlags: v1.CommandLineFlags{
90+
Plugin: &v1.PluginCommandLineFlags{
91+
DeviceIDStrategy: ptr(v1.DeviceIDStrategyUUID),
92+
},
93+
},
94+
},
95+
},
96+
cdiHandler: &cdi.InterfaceMock{
97+
QualifiedNameFunc: func(c string, s string) string {
98+
return "nvidia.com/" + c + "=" + s
99+
},
100+
},
101+
deviceListStrategies: v1.DeviceListStrategies{"envvar": true},
102+
}
103+
104+
response, err := plugin.Allocate(context.TODO(), tc.request)
105+
require.EqualValues(t, tc.expectedError, err)
106+
require.EqualValues(t, tc.expectedResponse, response)
107+
})
108+
}
109+
}
110+
30111
func TestCDIAllocateResponse(t *testing.T) {
31112
testCases := []struct {
32113
description string
@@ -166,3 +247,7 @@ func TestCDIAllocateResponse(t *testing.T) {
166247
})
167248
}
168249
}
250+
251+
func ptr[T any](x T) *T {
252+
return &x
253+
}

0 commit comments

Comments
 (0)