Skip to content

Commit dc87dcf

Browse files
committed
Make CDI device requests consistent with other methods
Following the refactoring of device request extraction, we can now make CDI device requests consistent with other methods. This change moves to using image.VisibleDevices instead of separate calls to CDIDevicesFromMounts and VisibleDevicesFromEnvVar. Signed-off-by: Evan Lezar <[email protected]>
1 parent f17d424 commit dc87dcf

File tree

5 files changed

+209
-77
lines changed

5 files changed

+209
-77
lines changed

internal/config/image/cuda_image.go

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,12 @@ type CUDA struct {
5656
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
5757
// The process environment is read (if present) to construc the CUDA Image.
5858
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
59+
if spec == nil {
60+
return New(opts...)
61+
}
62+
5963
var env []string
60-
if spec != nil && spec.Process != nil {
64+
if spec.Process != nil {
6165
env = spec.Process.Env
6266
}
6367

@@ -219,19 +223,12 @@ func parseMajorMinorVersion(version string) (string, error) {
219223
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/
220224
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
221225
var hasCDIdevice bool
222-
for _, device := range i.VisibleDevicesFromEnvVar() {
226+
for _, device := range i.VisibleDevices() {
223227
if !parser.IsQualifiedName(device) {
224228
return false
225229
}
226230
hasCDIdevice = true
227231
}
228-
229-
for _, device := range i.DevicesFromMounts() {
230-
if !strings.HasPrefix(device, "cdi/") {
231-
return false
232-
}
233-
hasCDIdevice = true
234-
}
235232
return hasCDIdevice
236233
}
237234

@@ -309,20 +306,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
309306
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
310307
func (i CUDA) visibleDevicesFromMounts() []string {
311308
var devices []string
312-
for _, device := range i.DevicesFromMounts() {
309+
for _, device := range i.requestsFromMounts() {
313310
switch {
314-
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
315-
continue
316311
case strings.HasPrefix(device, volumeMountDevicePrefixImex):
317312
continue
313+
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
314+
name, err := cdiDeviceMountRequest(device).qualifiedName()
315+
if err != nil {
316+
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err)
317+
continue
318+
}
319+
devices = append(devices, name)
320+
default:
321+
devices = append(devices, device)
318322
}
319-
devices = append(devices, device)
323+
320324
}
321325
return devices
322326
}
323327

324-
// DevicesFromMounts returns a list of device specified as mounts.
325-
func (i CUDA) DevicesFromMounts() []string {
328+
// requestsFromMounts returns a list of device specified as mounts.
329+
func (i CUDA) requestsFromMounts() []string {
326330
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
327331
seen := make(map[string]bool)
328332
var devices []string
@@ -354,23 +358,30 @@ func (i CUDA) DevicesFromMounts() []string {
354358
return devices
355359
}
356360

357-
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
358-
func (i CUDA) CDIDevicesFromMounts() []string {
359-
var devices []string
360-
for _, mountDevice := range i.DevicesFromMounts() {
361-
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
362-
continue
363-
}
364-
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
365-
if len(parts) != 3 {
366-
continue
367-
}
368-
vendor := parts[0]
369-
class := parts[1]
370-
device := parts[2]
371-
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
361+
// a cdiDeviceMountRequest represents a CDI device requests as a mount.
362+
// Here the host path /dev/null is mounted to a particular path in the container.
363+
// The container path has the form:
364+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
365+
// or
366+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
367+
type cdiDeviceMountRequest string
368+
369+
// qualifiedName returns the fully-qualified name of the CDI device.
370+
func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
371+
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
372+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
372373
}
373-
return devices
374+
375+
requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
376+
if parser.IsQualifiedName(requestedDevice) {
377+
return requestedDevice, nil
378+
}
379+
380+
parts := strings.SplitN(requestedDevice, "/", 3)
381+
if len(parts) != 3 {
382+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
383+
}
384+
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
374385
}
375386

376387
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
@@ -385,7 +396,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
385396
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
386397
func (i CUDA) ImexChannelsFromMounts() []string {
387398
var channels []string
388-
for _, mountDevice := range i.DevicesFromMounts() {
399+
for _, mountDevice := range i.requestsFromMounts() {
389400
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
390401
continue
391402
}

internal/config/image/cuda_image_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
487487
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
488488
},
489489
{
490-
description: "cdi devices are ignored",
491-
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
492-
expectedDevices: []string{"GPU0", "GPU1"},
490+
description: "cdi devices are included",
491+
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
492+
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
493493
},
494494
{
495495
description: "imex devices are ignored",

internal/info/auto_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
184184
expectedMode: "legacy",
185185
},
186186
{
187-
description: "cdi mount and non-CDI envvar resolves to legacy",
187+
description: "cdi mount and non-CDI envvar resolves to cdi",
188188
mode: "auto",
189189
envmap: map[string]string{
190190
"NVIDIA_VISIBLE_DEVICES": "0",
@@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
197197
"tegra": false,
198198
"nvgpu": false,
199199
},
200+
expectedMode: "cdi",
201+
},
202+
{
203+
description: "non-cdi mount and CDI envvar resolves to legacy",
204+
mode: "auto",
205+
envmap: map[string]string{
206+
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
207+
},
208+
mounts: []string{
209+
"/var/run/nvidia-container-devices/0",
210+
},
211+
info: map[string]bool{
212+
"nvml": true,
213+
"tegra": false,
214+
"nvgpu": false,
215+
},
200216
expectedMode: "legacy",
201217
},
202218
}
@@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
232248
image, _ := image.New(
233249
image.WithEnvMap(tc.envmap),
234250
image.WithMounts(mounts),
251+
image.WithAcceptDeviceListAsVolumeMounts(true),
252+
image.WithAcceptEnvvarUnprivileged(true),
235253
)
236254
mode := resolveMode(logger, tc.mode, image, properties)
237255
require.EqualValues(t, tc.expectedMode, mode)

internal/modifier/cdi.go

Lines changed: 77 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ import (
3434
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
3535
// used to select the devices to include.
3636
func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
37-
devices, err := getDevicesFromImage(logger, cfg, image)
38-
if err != nil {
39-
return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
40-
}
37+
deviceRequestor := newCDIDeviceRequestor(
38+
logger,
39+
image,
40+
cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
41+
)
42+
devices := deviceRequestor.DeviceRequests()
4143
if len(devices) == 0 {
4244
logger.Debugf("No devices requested; no modification required.")
4345
return nil, nil
@@ -64,63 +66,59 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
6466
)
6567
}
6668

67-
func getDevicesFromImage(logger logger.Interface, cfg *config.Config, container image.CUDA) ([]string, error) {
68-
annotationDevices, err := getAnnotationDevices(container)
69+
type deviceRequestor interface {
70+
DeviceRequests() []string
71+
}
72+
73+
type cdiDeviceRequestor struct {
74+
image image.CUDA
75+
logger logger.Interface
76+
defaultKind string
77+
}
78+
79+
func newCDIDeviceRequestor(logger logger.Interface, image image.CUDA, defaultKind string) deviceRequestor {
80+
c := &cdiDeviceRequestor{
81+
logger: logger,
82+
image: image,
83+
defaultKind: defaultKind,
84+
}
85+
return withUniqueDevices(c)
86+
}
87+
88+
func (c *cdiDeviceRequestor) DeviceRequests() []string {
89+
if c == nil {
90+
return nil
91+
}
92+
annotationDevices, err := getAnnotationDevices(c.image)
6993
if err != nil {
70-
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
94+
c.logger.Warningf("failed to get device requests from container annotations: %v; ignoring.", err)
95+
annotationDevices = nil
7196
}
7297
if len(annotationDevices) > 0 {
73-
return annotationDevices, nil
74-
}
75-
76-
if cfg.AcceptDeviceListAsVolumeMounts {
77-
mountDevices := container.CDIDevicesFromMounts()
78-
if len(mountDevices) > 0 {
79-
return mountDevices, nil
80-
}
98+
return annotationDevices
8199
}
82100

83101
var devices []string
84-
seen := make(map[string]bool)
85-
for _, name := range container.VisibleDevicesFromEnvVar() {
102+
for _, name := range c.image.VisibleDevices() {
86103
if !parser.IsQualifiedName(name) {
87-
name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
88-
}
89-
if seen[name] {
90-
logger.Debugf("Ignoring duplicate device %q", name)
91-
continue
104+
name = fmt.Sprintf("%s=%s", c.defaultKind, name)
92105
}
93106
devices = append(devices, name)
94107
}
95108

96-
if len(devices) == 0 {
97-
return nil, nil
98-
}
99-
100-
if cfg.AcceptEnvvarUnprivileged || container.IsPrivileged() {
101-
return devices, nil
102-
}
103-
104-
logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)
105-
106-
return nil, nil
109+
return devices
107110
}
108111

109112
// getAnnotationDevices returns a list of devices specified in the annotations.
110113
// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of
111114
// fully-qualified CDI devices names. If any device name is not fully-quality an error is returned.
112115
// The list of returned devices is deduplicated.
113116
func getAnnotationDevices(image image.CUDA) ([]string, error) {
114-
seen := make(map[string]bool)
115117
var annotationDevices []string
116118
for _, device := range image.CDIDeviceRequestsFromAnnotations() {
117119
if !parser.IsQualifiedName(device) {
118120
return nil, fmt.Errorf("invalid device name %q in annotations", device)
119121
}
120-
if seen[device] {
121-
continue
122-
}
123-
seen[device] = true
124122
annotationDevices = append(annotationDevices, device)
125123
}
126124
return annotationDevices, nil
@@ -147,15 +145,15 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
147145
if err != nil {
148146
return nil, fmt.Errorf("failed to generate CDI spec: %w", err)
149147
}
150-
cdiModifier, err := cdi.New(
148+
cdiDeviceRequestor, err := cdi.New(
151149
cdi.WithLogger(logger),
152150
cdi.WithSpec(spec.Raw()),
153151
)
154152
if err != nil {
155153
return nil, fmt.Errorf("failed to construct CDI modifier: %w", err)
156154
}
157155

158-
return cdiModifier, nil
156+
return cdiDeviceRequestor, nil
159157
}
160158

161159
func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) {
@@ -193,3 +191,42 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic
193191
spec.WithClass("gpu"),
194192
)
195193
}
194+
195+
func deviceRequestorFromImage(image image.CUDA) deviceRequestor {
196+
return &fromImage{image}
197+
}
198+
199+
type fromImage struct {
200+
image.CUDA
201+
}
202+
203+
func (f *fromImage) DeviceRequests() []string {
204+
if f == nil {
205+
return nil
206+
}
207+
return f.CUDA.VisibleDevices()
208+
}
209+
210+
type deduplicatedDeviceRequestor struct {
211+
deviceRequestor
212+
}
213+
214+
func withUniqueDevices(deviceRequestor deviceRequestor) deviceRequestor {
215+
return &deduplicatedDeviceRequestor{deviceRequestor: deviceRequestor}
216+
}
217+
218+
func (d *deduplicatedDeviceRequestor) DeviceRequests() []string {
219+
if d == nil {
220+
return nil
221+
}
222+
seen := make(map[string]bool)
223+
var devices []string
224+
for _, device := range d.deviceRequestor.DeviceRequests() {
225+
if seen[device] {
226+
continue
227+
}
228+
seen[device] = true
229+
devices = append(devices, device)
230+
}
231+
return devices
232+
}

0 commit comments

Comments
 (0)