Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions internal/config/image/cuda_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package image
import (
"fmt"
"path/filepath"
"slices"
"strconv"
"strings"

Expand Down Expand Up @@ -143,8 +144,8 @@ func (i CUDA) HasDisableRequire() bool {
return false
}

// DevicesFromEnvvars returns the devices requested by the image through environment variables
func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
// devicesFromEnvvars returns the devices requested by the image through environment variables
func (i CUDA) devicesFromEnvvars(envVars ...string) []string {
// We concatenate all the devices from the specified env.
var isSet bool
var devices []string
Expand All @@ -165,15 +166,15 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {

// Environment variable unset with legacy image: default to "all".
if !isSet && len(devices) == 0 && i.IsLegacy() {
return NewVisibleDevices("all")
devices = []string{"all"}
}

// Environment variable unset or empty or "void": fall back to "void"
if len(devices) == 0 || requested["void"] {
return NewVisibleDevices("void")
devices = []string{"void"}
}

return NewVisibleDevices(devices...)
return NewVisibleDevices(devices...).List()
}

// GetDriverCapabilities returns the requested driver capabilities.
Expand Down Expand Up @@ -232,6 +233,22 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
return hasCDIdevice
}

// visibleEnvVars reports which environment variables govern device visibility
// for this image. Every entry of preferredVisibleDeviceEnvVars that is set on
// the image is returned, in the configured order. When none of them is set,
// the result is a single-element slice containing NVIDIA_VISIBLE_DEVICES.
func (i CUDA) visibleEnvVars() []string {
	var present []string
	for _, name := range i.preferredVisibleDeviceEnvVars {
		if i.HasEnvvar(name) {
			present = append(present, name)
		}
	}
	// No preferred variable is set: use the default variable instead.
	if len(present) == 0 {
		return []string{EnvVarNvidiaVisibleDevices}
	}
	return present
}

// VisibleDevices returns a list of devices requested in the container image.
// If volume mount requests are enabled these are returned if requested,
// otherwise device requests through environment variables are considered.
Expand All @@ -253,7 +270,7 @@ func (i CUDA) VisibleDevices() []string {
}

// Fall back to reading from the environment variable if privileges are correct
envVarDeviceRequests := i.VisibleDevicesFromEnvVar()
envVarDeviceRequests := i.visibleDevicesFromEnvVar()
if len(envVarDeviceRequests) == 0 {
return nil
}
Expand All @@ -265,7 +282,10 @@ func (i CUDA) VisibleDevices() []string {
}

// We log a warning if we are ignoring the environment variable requests.
i.logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES in unprivileged container")
envVars := i.visibleEnvVars()
if len(envVars) > 0 {
i.logger.Warningf("Ignoring devices requested by environment variable(s) in unprivileged container: %v", envVars)
}

return nil
}
Expand All @@ -281,31 +301,34 @@ func (i CUDA) cdiDeviceRequestsFromAnnotations() []string {
return nil
}

var devices []string
for key, value := range i.annotations {
var annotationKeys []string
for key := range i.annotations {
for _, prefix := range i.annotationsPrefixes {
if strings.HasPrefix(key, prefix) {
devices = append(devices, strings.Split(value, ",")...)
annotationKeys = append(annotationKeys, key)
// There is no need to check additional prefixes since we
// typically deduplicate devices in any case.
break
}
}
}
// We sort the annotationKeys for consistent results.
slices.Sort(annotationKeys)

var devices []string
for _, key := range annotationKeys {
devices = append(devices, strings.Split(i.annotations[key], ",")...)
}
return devices
}

// VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
// visibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
// If any of the preferredVisibleDeviceEnvVars are present in the image, they
// are used to determine the visible devices. If this is not the case, the
// NVIDIA_VISIBLE_DEVICES environment variable is used.
func (i CUDA) VisibleDevicesFromEnvVar() []string {
for _, envVar := range i.preferredVisibleDeviceEnvVars {
if i.HasEnvvar(envVar) {
return i.DevicesFromEnvvars(i.preferredVisibleDeviceEnvVars...).List()
}
}
return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List()
func (i CUDA) visibleDevicesFromEnvVar() []string {
envVars := i.visibleEnvVars()
return i.devicesFromEnvvars(envVars...)
}

// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
Expand Down Expand Up @@ -391,7 +414,7 @@ func (m cdiDeviceMountRequest) qualifiedName() (string, error) {

// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromEnvVar() []string {
imexChannels := i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List()
imexChannels := i.devicesFromEnvvars(EnvVarNvidiaImexChannels)
if len(imexChannels) == 1 && imexChannels[0] == "all" {
return nil
}
Expand Down
102 changes: 88 additions & 14 deletions internal/config/image/cuda_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
)

require.NoError(t, err)
devices := image.VisibleDevicesFromEnvVar()
devices := image.visibleDevicesFromEnvVar()
require.EqualValues(t, tc.expectedDevices, devices)
})
}
Expand Down Expand Up @@ -508,13 +508,15 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {

func TestVisibleDevices(t *testing.T) {
var tests = []struct {
description string
mountDevices []specs.Mount
envvarDevices string
privileged bool
acceptUnprivileged bool
acceptMounts bool
expectedDevices []string
description string
mountDevices []specs.Mount
envvarDevices string
privileged bool
acceptUnprivileged bool
acceptMounts bool
preferredVisibleDeviceEnvVars []string
env map[string]string
expectedDevices []string
}{
{
description: "Mount devices, unprivileged, no accept unprivileged",
Expand Down Expand Up @@ -597,20 +599,92 @@ func TestVisibleDevices(t *testing.T) {
acceptMounts: false,
expectedDevices: nil,
},
// New test cases for visibleEnvVars functionality
{
description: "preferred env var set and present in env, privileged",
mountDevices: nil,
envvarDevices: "",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
},
expectedDevices: []string{"GPU-12345"},
},
{
description: "preferred env var set and present in env, unprivileged but accepted",
mountDevices: nil,
envvarDevices: "",
privileged: false,
acceptUnprivileged: true,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
},
expectedDevices: []string{"GPU-12345"},
},
{
description: "preferred env var set and present in env, unprivileged and not accepted",
mountDevices: nil,
envvarDevices: "",
privileged: false,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
},
expectedDevices: nil,
},
{
description: "multiple preferred env vars, both present, privileged",
mountDevices: nil,
envvarDevices: "",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
"DOCKER_RESOURCE_GPUS_ADDITIONAL": "GPU-67890",
},
expectedDevices: []string{"GPU-12345", "GPU-67890"},
},
{
description: "preferred env var not present, fallback to NVIDIA_VISIBLE_DEVICES, privileged",
mountDevices: nil,
envvarDevices: "GPU-12345",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
EnvVarNvidiaVisibleDevices: "GPU-12345",
},
expectedDevices: []string{"GPU-12345"},
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
// Wrap the call to getDevices() in a closure.
// Create env map with both NVIDIA_VISIBLE_DEVICES and any additional env vars
env := make(map[string]string)
if tc.envvarDevices != "" {
env[EnvVarNvidiaVisibleDevices] = tc.envvarDevices
}
for k, v := range tc.env {
env[k] = v
}

image, err := New(
WithEnvMap(
map[string]string{
EnvVarNvidiaVisibleDevices: tc.envvarDevices,
},
),
WithEnvMap(env),
WithMounts(tc.mountDevices),
WithPrivileged(tc.privileged),
WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
WithPreferredVisibleDevicesEnvVars(tc.preferredVisibleDeviceEnvVars...),
)
require.NoError(t, err)
require.Equal(t, tc.expectedDevices, image.VisibleDevices())
Expand Down
2 changes: 1 addition & 1 deletion internal/modifier/cdi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestDeviceRequests(t *testing.T) {
"another-prefix/bar": "example.com/device=baz",
},
},
expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"},
expectedDevices: []string{"example.com/device=baz", "example.com/device=bar"},
},
{
description: "multiple matching annotations with duplicate devices",
Expand Down
2 changes: 1 addition & 1 deletion internal/modifier/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
// NewCSVModifier creates a modifier that applies modifications to an OCI spec if required by the runtime wrapper.
// The modifications are defined by CSV MountSpecs.
func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image.CUDA) (oci.SpecModifier, error) {
if devices := container.VisibleDevicesFromEnvVar(); len(devices) == 0 {
if devices := container.VisibleDevices(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested")
return nil, nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/modifier/gated.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import (
//
// If no devices are selected, no changes are made.
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
if devices := image.VisibleDevices(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested")
return nil, nil
}
Expand Down
20 changes: 11 additions & 9 deletions internal/modifier/graphics.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ import (

// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
if required, reason := requiresGraphicsModifier(containerImage); !required {
logger.Infof("No graphics modifier required: %v", reason)
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, container image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
devices, reason := requiresGraphicsModifier(container)
if len(devices) == 0 {
logger.Infof("No graphics modifier required; %v", reason)
return nil, nil
}

Expand All @@ -48,7 +49,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI
devRoot := driver.Root
drmNodes, err := discover.NewDRMNodesDiscoverer(
logger,
containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices),
image.NewVisibleDevices(devices...),
devRoot,
hookCreator,
)
Expand All @@ -64,14 +65,15 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI
}

// requiresGraphicsModifier determines whether a graphics modifier is required.
func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) {
if devices := cudaImage.VisibleDevicesFromEnvVar(); len(devices) == 0 {
return false, "no devices requested"
func requiresGraphicsModifier(cudaImage image.CUDA) ([]string, string) {
devices := cudaImage.VisibleDevices()
if len(devices) == 0 {
return nil, "no devices requested"
}

if !cudaImage.GetDriverCapabilities().Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) {
return false, "no required capabilities requested"
return nil, "no required capabilities requested"
}

return true, ""
return devices, ""
}
Loading
Loading