Skip to content

Commit 395b3d5

Browse files
authored
Merge pull request #1434 from elezar/deduplicate-device-requests
Deduplicate requested device IDs
2 parents dd6edf8 + 75633e2 commit 395b3d5

File tree

4 files changed

+422
-7
lines changed

4 files changed

+422
-7
lines changed

internal/plugin/server.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ func (plugin *nvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.
319319
}
320320

321321
func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*pluginapi.ContainerAllocateResponse, error) {
322-
deviceIDs := plugin.deviceIDsFromAnnotatedDeviceIDs(requestIds)
322+
deviceIDs := plugin.uniqueDeviceIDsFromAnnotatedDeviceIDs(requestIds)
323323

324324
// Create an empty response that will be updated as required below.
325325
response := &pluginapi.ContainerAllocateResponse{
@@ -348,16 +348,16 @@ func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
348348
if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyVolumeMounts) {
349349
plugin.updateResponseForDeviceMounts(response, deviceIDs...)
350350
}
351-
if *plugin.config.Flags.Plugin.PassDeviceSpecs {
351+
if plugin.config.Flags.Plugin.PassDeviceSpecs != nil && *plugin.config.Flags.Plugin.PassDeviceSpecs {
352352
response.Devices = append(response.Devices, plugin.apiDeviceSpecs(*plugin.config.Flags.NvidiaDevRoot, requestIds)...)
353353
}
354-
if *plugin.config.Flags.GDRCopyEnabled {
354+
if plugin.config.Flags.GDRCopyEnabled != nil && *plugin.config.Flags.GDRCopyEnabled {
355355
response.Envs["NVIDIA_GDRCOPY"] = "enabled"
356356
}
357-
if *plugin.config.Flags.GDSEnabled {
357+
if plugin.config.Flags.GDSEnabled != nil && *plugin.config.Flags.GDSEnabled {
358358
response.Envs["NVIDIA_GDS"] = "enabled"
359359
}
360-
if *plugin.config.Flags.MOFEDEnabled {
360+
if plugin.config.Flags.MOFEDEnabled != nil && *plugin.config.Flags.MOFEDEnabled {
361361
response.Envs["NVIDIA_MOFED"] = "enabled"
362362
}
363363
return response, nil
@@ -451,15 +451,24 @@ func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Durat
451451
return c, nil
452452
}
453453

454-
func (plugin *nvidiaDevicePlugin) deviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
454+
func (plugin *nvidiaDevicePlugin) uniqueDeviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
455455
var deviceIDs []string
456456
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyUUID {
457457
deviceIDs = rm.AnnotatedIDs(ids).GetIDs()
458458
}
459459
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyIndex {
460460
deviceIDs = plugin.rm.Devices().Subset(ids).GetIndices()
461461
}
462-
return deviceIDs
462+
var uniqueIDs []string
463+
seen := make(map[string]bool)
464+
for _, id := range deviceIDs {
465+
if seen[id] {
466+
continue
467+
}
468+
seen[id] = true
469+
uniqueIDs = append(uniqueIDs, id)
470+
}
471+
return uniqueIDs
463472
}
464473

465474
func (plugin *nvidiaDevicePlugin) apiDevices() []*pluginapi.Device {

internal/plugin/server_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package plugin
1818

1919
import (
20+
"context"
2021
"testing"
2122

2223
"github.com/stretchr/testify/require"
@@ -25,8 +26,88 @@ import (
2526
v1 "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
2627
"github.com/NVIDIA/k8s-device-plugin/internal/cdi"
2728
"github.com/NVIDIA/k8s-device-plugin/internal/imex"
29+
"github.com/NVIDIA/k8s-device-plugin/internal/rm"
2830
)
2931

32+
func TestAllocate(t *testing.T) {
33+
testCases := []struct {
34+
description string
35+
request *pluginapi.AllocateRequest
36+
expectedError error
37+
expectedResponse *pluginapi.AllocateResponse
38+
}{
39+
{
40+
description: "single device",
41+
request: &pluginapi.AllocateRequest{
42+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
43+
{
44+
DevicesIDs: []string{"foo"},
45+
},
46+
},
47+
},
48+
expectedResponse: &pluginapi.AllocateResponse{
49+
ContainerResponses: []*pluginapi.ContainerAllocateResponse{
50+
{
51+
Envs: map[string]string{
52+
"NVIDIA_VISIBLE_DEVICES": "foo",
53+
},
54+
},
55+
},
56+
},
57+
},
58+
{
59+
description: "duplicate device IDs",
60+
request: &pluginapi.AllocateRequest{
61+
ContainerRequests: []*pluginapi.ContainerAllocateRequest{
62+
{
63+
DevicesIDs: []string{"foo", "bar", "foo"},
64+
},
65+
},
66+
},
67+
expectedResponse: &pluginapi.AllocateResponse{
68+
ContainerResponses: []*pluginapi.ContainerAllocateResponse{
69+
{
70+
Envs: map[string]string{
71+
"NVIDIA_VISIBLE_DEVICES": "foo,bar",
72+
},
73+
},
74+
},
75+
},
76+
},
77+
}
78+
79+
for _, tc := range testCases {
80+
t.Run(tc.description, func(t *testing.T) {
81+
plugin := nvidiaDevicePlugin{
82+
rm: &rm.ResourceManagerMock{
83+
ValidateRequestFunc: func(annotatedIDs rm.AnnotatedIDs) error {
84+
return nil
85+
},
86+
},
87+
config: &v1.Config{
88+
Flags: v1.Flags{
89+
CommandLineFlags: v1.CommandLineFlags{
90+
Plugin: &v1.PluginCommandLineFlags{
91+
DeviceIDStrategy: ptr(v1.DeviceIDStrategyUUID),
92+
},
93+
},
94+
},
95+
},
96+
cdiHandler: &cdi.InterfaceMock{
97+
QualifiedNameFunc: func(c string, s string) string {
98+
return "nvidia.com/" + c + "=" + s
99+
},
100+
},
101+
deviceListStrategies: v1.DeviceListStrategies{"envvar": true},
102+
}
103+
104+
response, err := plugin.Allocate(context.TODO(), tc.request)
105+
require.EqualValues(t, tc.expectedError, err)
106+
require.EqualValues(t, tc.expectedResponse, response)
107+
})
108+
}
109+
}
110+
30111
func TestCDIAllocateResponse(t *testing.T) {
31112
testCases := []struct {
32113
description string
@@ -169,3 +250,7 @@ func TestCDIAllocateResponse(t *testing.T) {
169250
})
170251
}
171252
}
253+
254+
func ptr[T any](x T) *T {
255+
return &x
256+
}

internal/rm/rm.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ type resourceManager struct {
3737
}
3838

3939
// ResourceManager provides an interface for listing a set of Devices and checking health on them
40+
//
41+
//go:generate moq -rm -fmt=goimports -stub -out rm_mock.go . ResourceManager
4042
type ResourceManager interface {
4143
Resource() spec.ResourceName
4244
Devices() Devices

0 commit comments

Comments
 (0)