Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (plugin *nvidiaDevicePlugin) Allocate(ctx context.Context, reqs *pluginapi.
}

func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*pluginapi.ContainerAllocateResponse, error) {
deviceIDs := plugin.deviceIDsFromAnnotatedDeviceIDs(requestIds)
deviceIDs := plugin.uniqueDeviceIDsFromAnnotatedDeviceIDs(requestIds)

// Create an empty response that will be updated as required below.
response := &pluginapi.ContainerAllocateResponse{
Expand Down Expand Up @@ -348,16 +348,16 @@ func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyVolumeMounts) {
plugin.updateResponseForDeviceMounts(response, deviceIDs...)
}
if *plugin.config.Flags.Plugin.PassDeviceSpecs {
if plugin.config.Flags.Plugin.PassDeviceSpecs != nil && *plugin.config.Flags.Plugin.PassDeviceSpecs {
response.Devices = append(response.Devices, plugin.apiDeviceSpecs(*plugin.config.Flags.NvidiaDevRoot, requestIds)...)
}
if *plugin.config.Flags.GDRCopyEnabled {
if plugin.config.Flags.GDRCopyEnabled != nil && *plugin.config.Flags.GDRCopyEnabled {
response.Envs["NVIDIA_GDRCOPY"] = "enabled"
}
if *plugin.config.Flags.GDSEnabled {
if plugin.config.Flags.GDSEnabled != nil && *plugin.config.Flags.GDSEnabled {
response.Envs["NVIDIA_GDS"] = "enabled"
}
if *plugin.config.Flags.MOFEDEnabled {
if plugin.config.Flags.MOFEDEnabled != nil && *plugin.config.Flags.MOFEDEnabled {
response.Envs["NVIDIA_MOFED"] = "enabled"
}
return response, nil
Expand Down Expand Up @@ -451,15 +451,24 @@ func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Durat
return c, nil
}

func (plugin *nvidiaDevicePlugin) deviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
func (plugin *nvidiaDevicePlugin) uniqueDeviceIDsFromAnnotatedDeviceIDs(ids []string) []string {
var deviceIDs []string
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyUUID {
deviceIDs = rm.AnnotatedIDs(ids).GetIDs()
}
if *plugin.config.Flags.Plugin.DeviceIDStrategy == spec.DeviceIDStrategyIndex {
deviceIDs = plugin.rm.Devices().Subset(ids).GetIndices()
}
return deviceIDs
var uniqueIDs []string
seen := make(map[string]bool)
for _, id := range deviceIDs {
if seen[id] {
continue
}
seen[id] = true
uniqueIDs = append(uniqueIDs, id)
}
return uniqueIDs
}

func (plugin *nvidiaDevicePlugin) apiDevices() []*pluginapi.Device {
Expand Down
85 changes: 85 additions & 0 deletions internal/plugin/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package plugin

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -25,8 +26,88 @@ import (
v1 "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
"github.com/NVIDIA/k8s-device-plugin/internal/cdi"
"github.com/NVIDIA/k8s-device-plugin/internal/imex"
"github.com/NVIDIA/k8s-device-plugin/internal/rm"
)

// TestAllocate calls Allocate on a plugin configured with the UUID device-ID
// strategy and the "envvar" device-list strategy, and compares the returned
// error and response against the expectation for each scenario. One scenario
// repeats a device ID in the request and expects it to appear only once in
// NVIDIA_VISIBLE_DEVICES.
func TestAllocate(t *testing.T) {
	// singleContainerRequest builds a request with one container asking for ids.
	singleContainerRequest := func(ids ...string) *pluginapi.AllocateRequest {
		return &pluginapi.AllocateRequest{
			ContainerRequests: []*pluginapi.ContainerAllocateRequest{
				{
					DevicesIDs: ids,
				},
			},
		}
	}
	// visibleDevicesResponse builds a one-container response whose only content
	// is the NVIDIA_VISIBLE_DEVICES environment variable.
	visibleDevicesResponse := func(value string) *pluginapi.AllocateResponse {
		return &pluginapi.AllocateResponse{
			ContainerResponses: []*pluginapi.ContainerAllocateResponse{
				{
					Envs: map[string]string{
						"NVIDIA_VISIBLE_DEVICES": value,
					},
				},
			},
		}
	}

	scenarios := []struct {
		description      string
		request          *pluginapi.AllocateRequest
		expectedError    error
		expectedResponse *pluginapi.AllocateResponse
	}{
		{
			description:      "single device",
			request:          singleContainerRequest("foo"),
			expectedResponse: visibleDevicesResponse("foo"),
		},
		{
			description:      "duplicate device IDs",
			request:          singleContainerRequest("foo", "bar", "foo"),
			expectedResponse: visibleDevicesResponse("foo,bar"),
		},
	}

	for i := range scenarios {
		scenario := scenarios[i]
		t.Run(scenario.description, func(t *testing.T) {
			// The mocked resource manager accepts every request.
			resourceManager := &rm.ResourceManagerMock{
				ValidateRequestFunc: func(annotatedIDs rm.AnnotatedIDs) error {
					return nil
				},
			}
			cdiHandler := &cdi.InterfaceMock{
				QualifiedNameFunc: func(c string, s string) string {
					return "nvidia.com/" + c + "=" + s
				},
			}
			// Only DeviceIDStrategy is set; all other flags stay at their
			// zero values.
			config := &v1.Config{
				Flags: v1.Flags{
					CommandLineFlags: v1.CommandLineFlags{
						Plugin: &v1.PluginCommandLineFlags{
							DeviceIDStrategy: ptr(v1.DeviceIDStrategyUUID),
						},
					},
				},
			}

			devicePlugin := nvidiaDevicePlugin{
				rm:                   resourceManager,
				config:               config,
				cdiHandler:           cdiHandler,
				deviceListStrategies: v1.DeviceListStrategies{"envvar": true},
			}

			got, err := devicePlugin.Allocate(context.TODO(), scenario.request)
			require.EqualValues(t, scenario.expectedError, err)
			require.EqualValues(t, scenario.expectedResponse, got)
		})
	}
}

func TestCDIAllocateResponse(t *testing.T) {
testCases := []struct {
description string
Expand Down Expand Up @@ -169,3 +250,7 @@ func TestCDIAllocateResponse(t *testing.T) {
})
}
}

func ptr[T any](x T) *T {
return &x
}
2 changes: 2 additions & 0 deletions internal/rm/rm.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type resourceManager struct {
}

// ResourceManager provides an interface for listing a set of Devices and checking health on them.
//
//go:generate moq -rm -fmt=goimports -stub -out rm_mock.go . ResourceManager
type ResourceManager interface {
Resource() spec.ResourceName
Devices() Devices
Expand Down
Loading