Skip to content

Commit bdcdcb7

Browse files
authored
Merge pull request #1132 from elezar/make-cdi-device-extraction-consistent
Make CDI device requests consistent with other methods
2 parents dba15ac + 8be03cf commit bdcdcb7

File tree

7 files changed

+381
-166
lines changed

7 files changed

+381
-166
lines changed

internal/config/image/builder.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ func New(opt ...Option) (CUDA, error) {
5050
if b.logger == nil {
5151
b.logger = logger.New()
5252
}
53-
5453
if b.env == nil {
5554
b.env = make(map[string]string)
5655
}
@@ -81,6 +80,20 @@ func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
8180
}
8281
}
8382

83+
func WithAnnotations(annotations map[string]string) Option {
84+
return func(b *builder) error {
85+
b.annotations = annotations
86+
return nil
87+
}
88+
}
89+
90+
func WithAnnotationsPrefixes(annotationsPrefixes []string) Option {
91+
return func(b *builder) error {
92+
b.annotationsPrefixes = annotationsPrefixes
93+
return nil
94+
}
95+
}
96+
8497
// WithDisableRequire sets the disable require option.
8598
func WithDisableRequire(disableRequire bool) Option {
8699
return func(b *builder) error {

internal/config/image/cuda_image.go

Lines changed: 81 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ const (
4242
type CUDA struct {
4343
logger logger.Interface
4444

45+
annotations map[string]string
4546
env map[string]string
4647
isPrivileged bool
4748
mounts []specs.Mount
4849

50+
annotationsPrefixes []string
4951
acceptDeviceListAsVolumeMounts bool
5052
acceptEnvvarUnprivileged bool
5153
preferredVisibleDeviceEnvVars []string
@@ -54,12 +56,17 @@ type CUDA struct {
5456
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
5557
// The process environment is read (if present) to construct the CUDA Image.
5658
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
59+
if spec == nil {
60+
return New(opts...)
61+
}
62+
5763
var env []string
58-
if spec != nil && spec.Process != nil {
64+
if spec.Process != nil {
5965
env = spec.Process.Env
6066
}
6167

6268
specOpts := []Option{
69+
WithAnnotations(spec.Annotations),
6370
WithEnv(env),
6471
WithMounts(spec.Mounts),
6572
WithPrivileged(IsPrivileged((*OCISpec)(spec))),
@@ -95,6 +102,10 @@ func (i CUDA) IsLegacy() bool {
95102
return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0
96103
}
97104

105+
func (i CUDA) IsPrivileged() bool {
106+
return i.isPrivileged
107+
}
108+
98109
// GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment
99110
// variables.
100111
func (i CUDA) GetRequirements() ([]string, error) {
@@ -212,19 +223,12 @@ func parseMajorMinorVersion(version string) (string, error) {
212223
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices.
213224
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
214225
var hasCDIdevice bool
215-
for _, device := range i.VisibleDevicesFromEnvVar() {
226+
for _, device := range i.VisibleDevices() {
216227
if !parser.IsQualifiedName(device) {
217228
return false
218229
}
219230
hasCDIdevice = true
220231
}
221-
222-
for _, device := range i.DevicesFromMounts() {
223-
if !strings.HasPrefix(device, "cdi/") {
224-
return false
225-
}
226-
hasCDIdevice = true
227-
}
228232
return hasCDIdevice
229233
}
230234

@@ -234,6 +238,12 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
234238
// In cases where environment variable requests required privileged containers,
235239
// such device requests are ignored.
236240
func (i CUDA) VisibleDevices() []string {
241+
// If annotation device requests are present, these are preferred.
242+
annotationDeviceRequests := i.cdiDeviceRequestsFromAnnotations()
243+
if len(annotationDeviceRequests) > 0 {
244+
return annotationDeviceRequests
245+
}
246+
237247
// If enabled, try and get the device list from volume mounts first
238248
if i.acceptDeviceListAsVolumeMounts {
239249
volumeMountDeviceRequests := i.visibleDevicesFromMounts()
@@ -260,6 +270,31 @@ func (i CUDA) VisibleDevices() []string {
260270
return nil
261271
}
262272

273+
// cdiDeviceRequestsFromAnnotations returns a list of devices specified in the
274+
// annotations.
275+
// Keys starting with the specified prefixes are considered and expected to
276+
// contain a comma-separated list of fully-qualified CDI device names.
277+
// The format of the requested devices is not checked and the list is not
278+
// deduplicated.
279+
func (i CUDA) cdiDeviceRequestsFromAnnotations() []string {
280+
if len(i.annotationsPrefixes) == 0 || len(i.annotations) == 0 {
281+
return nil
282+
}
283+
284+
var devices []string
285+
for key, value := range i.annotations {
286+
for _, prefix := range i.annotationsPrefixes {
287+
if strings.HasPrefix(key, prefix) {
288+
devices = append(devices, strings.Split(value, ",")...)
289+
// There is no need to check additional prefixes since we
290+
// typically deduplicate devices in any case.
291+
break
292+
}
293+
}
294+
}
295+
return devices
296+
}
297+
263298
// VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
264299
// If any of the preferredVisibleDeviceEnvVars are present in the image, they
265300
// are used to determine the visible devices. If this is not the case, the
@@ -276,20 +311,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
276311
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
277312
func (i CUDA) visibleDevicesFromMounts() []string {
278313
var devices []string
279-
for _, device := range i.DevicesFromMounts() {
314+
for _, device := range i.requestsFromMounts() {
280315
switch {
281-
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
282-
continue
283316
case strings.HasPrefix(device, volumeMountDevicePrefixImex):
284317
continue
318+
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
319+
name, err := cdiDeviceMountRequest(device).qualifiedName()
320+
if err != nil {
321+
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err)
322+
continue
323+
}
324+
devices = append(devices, name)
325+
default:
326+
devices = append(devices, device)
285327
}
286-
devices = append(devices, device)
328+
287329
}
288330
return devices
289331
}
290332

291-
// DevicesFromMounts returns a list of device specified as mounts.
292-
func (i CUDA) DevicesFromMounts() []string {
333+
// requestsFromMounts returns a list of devices specified as mounts.
334+
func (i CUDA) requestsFromMounts() []string {
293335
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
294336
seen := make(map[string]bool)
295337
var devices []string
@@ -321,23 +363,30 @@ func (i CUDA) DevicesFromMounts() []string {
321363
return devices
322364
}
323365

324-
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
325-
func (i CUDA) CDIDevicesFromMounts() []string {
326-
var devices []string
327-
for _, mountDevice := range i.DevicesFromMounts() {
328-
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
329-
continue
330-
}
331-
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
332-
if len(parts) != 3 {
333-
continue
334-
}
335-
vendor := parts[0]
336-
class := parts[1]
337-
device := parts[2]
338-
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
366+
// a cdiDeviceMountRequest represents a CDI device request as a mount.
367+
// Here the host path /dev/null is mounted to a particular path in the container.
368+
// The container path has the form:
369+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
370+
// or
371+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
372+
type cdiDeviceMountRequest string
373+
374+
// qualifiedName returns the fully-qualified name of the CDI device.
375+
func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
376+
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
377+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
339378
}
340-
return devices
379+
380+
requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
381+
if parser.IsQualifiedName(requestedDevice) {
382+
return requestedDevice, nil
383+
}
384+
385+
parts := strings.SplitN(requestedDevice, "/", 3)
386+
if len(parts) != 3 {
387+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
388+
}
389+
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
341390
}
342391

343392
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
@@ -352,7 +401,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
352401
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
353402
func (i CUDA) ImexChannelsFromMounts() []string {
354403
var channels []string
355-
for _, mountDevice := range i.DevicesFromMounts() {
404+
for _, mountDevice := range i.requestsFromMounts() {
356405
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
357406
continue
358407
}

internal/config/image/cuda_image_test.go

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
487487
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
488488
},
489489
{
490-
description: "cdi devices are ignored",
491-
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
492-
expectedDevices: []string{"GPU0", "GPU1"},
490+
description: "cdi devices are included",
491+
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
492+
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
493493
},
494494
{
495495
description: "imex devices are ignored",
@@ -649,6 +649,73 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
649649
}
650650
}
651651

652+
func TestCDIDeviceRequestsFromAnnotations(t *testing.T) {
653+
testCases := []struct {
654+
description string
655+
prefixes []string
656+
annotations map[string]string
657+
expectedDevices []string
658+
}{
659+
{
660+
description: "no annotations",
661+
},
662+
{
663+
description: "no matching annotations",
664+
prefixes: []string{"not-prefix/"},
665+
annotations: map[string]string{
666+
"prefix/foo": "example.com/device=bar",
667+
},
668+
},
669+
{
670+
description: "single matching annotation",
671+
prefixes: []string{"prefix/"},
672+
annotations: map[string]string{
673+
"prefix/foo": "example.com/device=bar",
674+
},
675+
expectedDevices: []string{"example.com/device=bar"},
676+
},
677+
{
678+
description: "multiple matching annotations",
679+
prefixes: []string{"prefix/", "another-prefix/"},
680+
annotations: map[string]string{
681+
"prefix/foo": "example.com/device=bar",
682+
"another-prefix/bar": "example.com/device=baz",
683+
},
684+
expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"},
685+
},
686+
{
687+
description: "multiple matching annotations with duplicate devices",
688+
prefixes: []string{"prefix/", "another-prefix/"},
689+
annotations: map[string]string{
690+
"prefix/foo": "example.com/device=bar",
691+
"another-prefix/bar": "example.com/device=bar",
692+
},
693+
expectedDevices: []string{"example.com/device=bar", "example.com/device=bar"},
694+
},
695+
{
696+
description: "invalid devices are returned as is",
697+
prefixes: []string{"prefix/"},
698+
annotations: map[string]string{
699+
"prefix/foo": "example.com/device",
700+
},
701+
expectedDevices: []string{"example.com/device"},
702+
},
703+
}
704+
705+
for _, tc := range testCases {
706+
t.Run(tc.description, func(t *testing.T) {
707+
image, err := New(
708+
WithAnnotationsPrefixes(tc.prefixes),
709+
WithAnnotations(tc.annotations),
710+
)
711+
require.NoError(t, err)
712+
713+
devices := image.cdiDeviceRequestsFromAnnotations()
714+
require.ElementsMatch(t, tc.expectedDevices, devices)
715+
})
716+
}
717+
}
718+
652719
func makeTestMounts(paths ...string) []specs.Mount {
653720
var mounts []specs.Mount
654721
for _, path := range paths {

internal/info/auto_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
184184
expectedMode: "legacy",
185185
},
186186
{
187-
description: "cdi mount and non-CDI envvar resolves to legacy",
187+
description: "cdi mount and non-CDI envvar resolves to cdi",
188188
mode: "auto",
189189
envmap: map[string]string{
190190
"NVIDIA_VISIBLE_DEVICES": "0",
@@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
197197
"tegra": false,
198198
"nvgpu": false,
199199
},
200+
expectedMode: "cdi",
201+
},
202+
{
203+
description: "non-cdi mount and CDI envvar resolves to legacy",
204+
mode: "auto",
205+
envmap: map[string]string{
206+
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
207+
},
208+
mounts: []string{
209+
"/var/run/nvidia-container-devices/0",
210+
},
211+
info: map[string]bool{
212+
"nvml": true,
213+
"tegra": false,
214+
"nvgpu": false,
215+
},
200216
expectedMode: "legacy",
201217
},
202218
}
@@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
232248
image, _ := image.New(
233249
image.WithEnvMap(tc.envmap),
234250
image.WithMounts(mounts),
251+
image.WithAcceptDeviceListAsVolumeMounts(true),
252+
image.WithAcceptEnvvarUnprivileged(true),
235253
)
236254
mode := resolveMode(logger, tc.mode, image, properties)
237255
require.EqualValues(t, tc.expectedMode, mode)

0 commit comments

Comments
 (0)