Skip to content

Commit a10c54e

Browse files
authored
Merge pull request #1247 from elezar/coherent-non-coherent
Generate CDI specs for coherent and non-coherent devices
2 parents c350d13 + 868963b commit a10c54e

File tree

24 files changed

+2081
-1622
lines changed

24 files changed

+2081
-1622
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 114 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package generate
1818

1919
import (
2020
"context"
21+
"errors"
2122
"fmt"
2223
"os"
2324
"path/filepath"
@@ -26,6 +27,7 @@ import (
2627
"github.com/urfave/cli/v3"
2728

2829
cdi "tags.cncf.io/container-device-interface/pkg/parser"
30+
"tags.cncf.io/container-device-interface/specs-go"
2931

3032
"github.com/NVIDIA/go-nvml/pkg/nvml"
3133

@@ -64,6 +66,8 @@ type options struct {
6466
disabledHooks []string
6567
enabledHooks []string
6668

69+
featureFlags []string
70+
6771
csv struct {
6872
files []string
6973
ignorePatterns []string
@@ -222,6 +226,13 @@ func (m command) build() *cli.Command {
222226
Destination: &opts.enabledHooks,
223227
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_ENABLED_HOOKS"),
224228
},
229+
&cli.StringSliceFlag{
230+
Name: "feature-flag",
231+
Aliases: []string{"feature-flags"},
232+
Usage: "specify feature flags for CDI spec generation",
233+
Destination: &opts.featureFlags,
234+
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS"),
235+
},
225236
},
226237
}
227238

@@ -276,21 +287,18 @@ func (m command) validateFlags(c *cli.Command, opts *options) error {
276287
}
277288

278289
func (m command) run(opts *options) error {
279-
spec, err := m.generateSpec(opts)
290+
specs, err := m.generateSpecs(opts)
280291
if err != nil {
281292
return fmt.Errorf("failed to generate CDI spec: %v", err)
282293
}
283-
m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version)
284294

285-
if opts.output == "" {
286-
_, err := spec.WriteTo(os.Stdout)
287-
if err != nil {
288-
return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err)
289-
}
290-
return nil
291-
}
295+
var errs error
296+
for _, spec := range specs {
297+
m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version)
292298

293-
return spec.Save(opts.output)
299+
errs = errors.Join(errs, spec.Save(opts.output))
300+
}
301+
return errs
294302
}
295303

296304
func formatFromFilename(filename string) string {
@@ -305,7 +313,34 @@ func formatFromFilename(filename string) string {
305313
return ""
306314
}
307315

308-
func (m command) generateSpec(opts *options) (spec.Interface, error) {
316+
type generatedSpecs struct {
317+
spec.Interface
318+
filenameInfix string
319+
}
320+
321+
func (g *generatedSpecs) Save(filename string) error {
322+
filename = g.updateFilename(filename)
323+
324+
if filename == "" {
325+
_, err := g.WriteTo(os.Stdout)
326+
if err != nil {
327+
return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err)
328+
}
329+
return nil
330+
}
331+
332+
return g.Interface.Save(filename)
333+
}
334+
335+
func (g generatedSpecs) updateFilename(filename string) string {
336+
if g.filenameInfix == "" || filename == "" {
337+
return filename
338+
}
339+
ext := filepath.Ext(filepath.Base(filename))
340+
return strings.TrimSuffix(filename, ext) + g.filenameInfix + ext
341+
}
342+
343+
func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
309344
var deviceNamers []nvcdi.DeviceNamer
310345
for _, strategy := range opts.deviceNameStrategies {
311346
deviceNamer, err := nvcdi.NewDeviceNamer(strategy)
@@ -329,6 +364,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
329364
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns),
330365
nvcdi.WithDisabledHooks(opts.disabledHooks...),
331366
nvcdi.WithEnabledHooks(opts.enabledHooks...),
367+
nvcdi.WithFeatureFlags(opts.featureFlags...),
332368
// We set the following to allow for dependency injection:
333369
nvcdi.WithNvmlLib(opts.nvmllib),
334370
}
@@ -338,7 +374,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
338374
return nil, fmt.Errorf("failed to create CDI library: %v", err)
339375
}
340376

341-
deviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
377+
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
342378
if err != nil {
343379
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
344380
}
@@ -348,16 +384,79 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
348384
return nil, fmt.Errorf("failed to create edits common for entities: %v", err)
349385
}
350386

351-
return spec.New(
387+
commonSpecOptions := []spec.Option{
352388
spec.WithVendor(opts.vendor),
353-
spec.WithClass(opts.class),
354-
spec.WithDeviceSpecs(deviceSpecs),
355389
spec.WithEdits(*commonEdits.ContainerEdits),
356390
spec.WithFormat(opts.format),
357391
spec.WithMergedDeviceOptions(
358392
transform.WithName(allDeviceName),
359393
transform.WithSkipIfExists(true),
360394
),
361395
spec.WithPermissions(0644),
396+
}
397+
398+
fullSpec, err := spec.New(
399+
append(commonSpecOptions,
400+
spec.WithClass(opts.class),
401+
spec.WithDeviceSpecs(allDeviceSpecs),
402+
)...,
362403
)
404+
if err != nil {
405+
return nil, err
406+
}
407+
var allSpecs []generatedSpecs
408+
409+
allSpecs = append(allSpecs, generatedSpecs{Interface: fullSpec, filenameInfix: ""})
410+
411+
deviceSpecsByDeviceCoherence := (deviceSpecs)(allDeviceSpecs).splitOnAnnotation("gpu.nvidia.com/coherent")
412+
413+
if coherentDeviceSpecs := deviceSpecsByDeviceCoherence["gpu.nvidia.com/coherent=true"]; len(coherentDeviceSpecs) > 0 {
414+
infix := ".coherent"
415+
coherentSpecs, err := spec.New(
416+
append(commonSpecOptions,
417+
spec.WithClass(opts.class+infix),
418+
spec.WithDeviceSpecs(coherentDeviceSpecs),
419+
)...,
420+
)
421+
if err != nil {
422+
return nil, err
423+
}
424+
allSpecs = append(allSpecs, generatedSpecs{Interface: coherentSpecs, filenameInfix: infix})
425+
}
426+
427+
if noncoherentDeviceSpecs := deviceSpecsByDeviceCoherence["gpu.nvidia.com/coherent=false"]; len(noncoherentDeviceSpecs) > 0 {
428+
infix := ".noncoherent"
429+
noncoherentSpecs, err := spec.New(
430+
append(commonSpecOptions,
431+
spec.WithClass(opts.class+infix),
432+
spec.WithDeviceSpecs(noncoherentDeviceSpecs),
433+
)...,
434+
)
435+
436+
if err != nil {
437+
return nil, err
438+
}
439+
allSpecs = append(allSpecs, generatedSpecs{Interface: noncoherentSpecs, filenameInfix: infix})
440+
}
441+
442+
return allSpecs, nil
443+
}
444+
445+
// deviceSpecs is a list of CDI device specs that can be grouped by the value
// of an annotation (see splitOnAnnotation).
type deviceSpecs []specs.Device
446+
447+
func (d deviceSpecs) splitOnAnnotation(key string) map[string][]specs.Device {
448+
splitSpecs := make(map[string][]specs.Device)
449+
450+
for _, deviceSpec := range d {
451+
if len(deviceSpec.Annotations) == 0 {
452+
continue
453+
}
454+
value, ok := deviceSpec.Annotations[key]
455+
if !ok {
456+
continue
457+
}
458+
splitSpecs[key+"="+value] = append(splitSpecs[key+"="+value], deviceSpec)
459+
}
460+
461+
return splitSpecs
363462
}

cmd/nvidia-ctk/cdi/generate/generate_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/stretchr/testify/require"
2929

3030
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
31+
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3132
)
3233

3334
func TestGenerateSpec(t *testing.T) {
@@ -476,12 +477,15 @@ containerEdits:
476477
}
477478
tc.options.nvmllib = server
478479

479-
spec, err := c.generateSpec(&tc.options)
480+
tc.options.featureFlags = []string{string(nvcdi.FeatureDisableCoherentAnnotations)}
481+
specs, err := c.generateSpecs(&tc.options)
480482
require.ErrorIs(t, err, tc.expectedError)
481483

482484
var buf bytes.Buffer
483-
_, err = spec.WriteTo(&buf)
484-
require.NoError(t, err)
485+
for _, spec := range specs {
486+
_, err = spec.WriteTo(&buf)
487+
require.NoError(t, err)
488+
}
485489

486490
require.Equal(t, strings.ReplaceAll(tc.expectedSpec, "{{ .driverRoot }}", driverRoot), buf.String())
487491
})

0 commit comments

Comments
 (0)