Skip to content

Commit 75633e2

Browse files
committed
Deduplicate requested device IDs
This change ensures that the incoming device IDs are deduplicated before updating the AllocateResponse. This avoids cases where the NVIDIA_VISIBLE_DEVICES envvar or CDI annotations contain repeated device UUIDs or INDICES that do not add additional modifications to the container. Signed-off-by: Evan Lezar <[email protected]>
1 parent 521a122 commit 75633e2

File tree

2 files changed

+97
-3
lines changed

2 files changed

+97
-3
lines changed

internal/plugin/server.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ func (plugin *nvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.
319319
}
320320

321321
func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*pluginapi.ContainerAllocateResponse, error) {
322-
deviceIDs := plugin.deviceIDsFromAnnotatedDeviceIDs(requestIds)
322+
deviceIDs := plugin.uniqueDeviceIDsFromAnnotatedDeviceIDs(requestIds)
323323

324324
// Create an empty response that will be updated as required below.
325325
response := &pluginapi.ContainerAllocateResponse{
@@ -451,15 +451,24 @@ func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Durat
451451
return c, nil
452452
}
453453

454-
func (plugin *nvidiaDevicePlugin) deviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
454+
func (plugin *nvidiaDevicePlugin) uniqueDeviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
455455
var deviceIDs []string
456456
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyUUID {
457457
deviceIDs = rm.AnnotatedIDs(ids).GetIDs()
458458
}
459459
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyIndex {
460460
deviceIDs = plugin.rm.Devices().Subset(ids).GetIndices()
461461
}
462-
return deviceIDs
462+
var uniqueIDs []string
463+
seen := make(map[string]bool)
464+
for _, id := range deviceIDs {
465+
if seen[id] {
466+
continue
467+
}
468+
seen[id] = true
469+
uniqueIDs = append(uniqueIDs, id)
470+
}
471+
return uniqueIDs
463472
}
464473

465474
func (plugin *nvidiaDevicePlugin) apiDevices() []*pluginapi.Device {

internal/plugin/server_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package plugin
1818

1919
import (
20+
"context"
2021
"testing"
2122

2223
"github.com/stretchr/testify/require"
@@ -25,8 +26,88 @@ import (
2526
v1 "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
2627
"github.com/NVIDIA/k8s-device-plugin/internal/cdi"
2728
"github.com/NVIDIA/k8s-device-plugin/internal/imex"
29+
"github.com/NVIDIA/k8s-device-plugin/internal/rm"
2830
)
2931

32+
func TestAllocate(t *testing.T) {
33+
testCases := []struct {
34+
description string
35+
request *pluginapi.AllocateRequest
36+
expectedError error
37+
expectedResponse *pluginapi.AllocateResponse
38+
}{
39+
{
40+
description: "single device",
41+
request: &pluginapi.AllocateRequest{
42+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
43+
{
44+
DevicesIDs: []string{"foo"},
45+
},
46+
},
47+
},
48+
expectedResponse: &pluginapi.AllocateResponse{
49+
ContainerResponses: []*pluginapi.ContainerAllocateResponse{
50+
{
51+
Envs: map[string]string{
52+
"NVIDIA_VISIBLE_DEVICES": "foo",
53+
},
54+
},
55+
},
56+
},
57+
},
58+
{
59+
description: "duplicate device IDs",
60+
request: &pluginapi.AllocateRequest{
61+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
62+
{
63+
DevicesIDs: []string{"foo", "bar", "foo"},
64+
},
65+
},
66+
},
67+
expectedResponse: &pluginapi.AllocateResponse{
68+
ContainerResponses: []*pluginapi.ContainerAllocateResponse{
69+
{
70+
Envs: map[string]string{
71+
"NVIDIA_VISIBLE_DEVICES": "foo,bar",
72+
},
73+
},
74+
},
75+
},
76+
},
77+
}
78+
79+
for _, tc := range testCases {
80+
t.Run(tc.description, func(t *testing.T) {
81+
plugin := nvidiaDevicePlugin{
82+
rm: &rm.ResourceManagerMock{
83+
ValidateRequestFunc: func(annotatedIDs rm.AnnotatedIDs) error {
84+
return nil
85+
},
86+
},
87+
config: &v1.Config{
88+
Flags: v1.Flags{
89+
CommandLineFlags: v1.CommandLineFlags{
90+
Plugin: &v1.PluginCommandLineFlags{
91+
DeviceIDStrategy: ptr(v1.DeviceIDStrategyUUID),
92+
},
93+
},
94+
},
95+
},
96+
cdiHandler: &cdi.InterfaceMock{
97+
QualifiedNameFunc: func(c string, s string) string {
98+
return "nvidia.com/" + c + "=" + s
99+
},
100+
},
101+
deviceListStrategies: v1.DeviceListStrategies{"envvar": true},
102+
}
103+
104+
response, err := plugin.Allocate(context.TODO(), tc.request)
105+
require.EqualValues(t, tc.expectedError, err)
106+
require.EqualValues(t, tc.expectedResponse, response)
107+
})
108+
}
109+
}
110+
30111
func TestCDIAllocateResponse(t *testing.T) {
31112
testCases := []struct {
32113
description string
@@ -169,3 +250,7 @@ func TestCDIAllocateResponse(t *testing.T) {
169250
})
170251
}
171252
}
253+
254+
func ptr[T any](x T) *T {
255+
return &x
256+
}

0 commit comments

Comments
 (0)