Skip to content

Commit 1dea726

Browse files
committed
Generate separate specs for coherent and noncoherent devices
With this change the nvidia-ctk cdi generate command generates CDI specs based on whether a device supports coherent access to system memory or not. In this case "regular" nvidia.com/gpu CDI specs are generated for all devices as well as nvidia.com/gpu.coherent and nvidia.com/gpu.noncoherent for devices that are either coherent or non-coherent. Adding the --feature-flag=disable-coherent-annotations command line argument to the nvidia-ctk cdi generate command will disable this. The "disable-coherent-annotations" feature flag can also be set in the nvcdi API in which case the generated CDI device specification will not include annotations indicating coherence. Signed-off-by: Evan Lezar <[email protected]>
1 parent b018e56 commit 1dea726

File tree

7 files changed

+177
-27
lines changed

7 files changed

+177
-27
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 114 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package generate
1818

1919
import (
2020
"context"
21+
"errors"
2122
"fmt"
2223
"os"
2324
"path/filepath"
@@ -26,6 +27,7 @@ import (
2627
"github.com/urfave/cli/v3"
2728

2829
cdi "tags.cncf.io/container-device-interface/pkg/parser"
30+
"tags.cncf.io/container-device-interface/specs-go"
2931

3032
"github.com/NVIDIA/go-nvml/pkg/nvml"
3133

@@ -63,6 +65,8 @@ type options struct {
6365
librarySearchPaths []string
6466
disabledHooks []string
6567

68+
featureFlags []string
69+
6670
csv struct {
6771
files []string
6872
ignorePatterns []string
@@ -214,6 +218,13 @@ func (m command) build() *cli.Command {
214218
Destination: &opts.disabledHooks,
215219
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DISABLED_HOOKS"),
216220
},
221+
&cli.StringSliceFlag{
222+
Name: "feature-flag",
223+
Aliases: []string{"feature-flags"},
224+
Usage: "specify feature flags for CDI spec generation",
225+
Destination: &opts.featureFlags,
226+
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS"),
227+
},
217228
},
218229
}
219230

@@ -262,21 +273,18 @@ func (m command) validateFlags(c *cli.Command, opts *options) error {
262273
}
263274

264275
func (m command) run(opts *options) error {
265-
spec, err := m.generateSpec(opts)
276+
specs, err := m.generateSpecs(opts)
266277
if err != nil {
267278
return fmt.Errorf("failed to generate CDI spec: %v", err)
268279
}
269-
m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version)
270280

271-
if opts.output == "" {
272-
_, err := spec.WriteTo(os.Stdout)
273-
if err != nil {
274-
return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err)
275-
}
276-
return nil
277-
}
281+
var errs error
282+
for _, spec := range specs {
283+
m.logger.Infof("Generated CDI spec with version %v", spec.Raw().Version)
278284

279-
return spec.Save(opts.output)
285+
errs = errors.Join(errs, spec.Save(opts.output))
286+
}
287+
return errs
280288
}
281289

282290
func formatFromFilename(filename string) string {
@@ -291,7 +299,34 @@ func formatFromFilename(filename string) string {
291299
return ""
292300
}
293301

294-
func (m command) generateSpec(opts *options) (spec.Interface, error) {
302+
type generatedSpecs struct {
303+
spec.Interface
304+
filenameInfix string
305+
}
306+
307+
func (g *generatedSpecs) Save(filename string) error {
308+
filename = g.updateFilename(filename)
309+
310+
if filename == "" {
311+
_, err := g.WriteTo(os.Stdout)
312+
if err != nil {
313+
return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err)
314+
}
315+
return nil
316+
}
317+
318+
return g.Interface.Save(filename)
319+
}
320+
321+
func (g generatedSpecs) updateFilename(filename string) string {
322+
if g.filenameInfix == "" || filename == "" {
323+
return filename
324+
}
325+
ext := filepath.Ext(filepath.Base(filename))
326+
return strings.TrimSuffix(filename, ext) + g.filenameInfix + ext
327+
}
328+
329+
func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
295330
var deviceNamers []nvcdi.DeviceNamer
296331
for _, strategy := range opts.deviceNameStrategies {
297332
deviceNamer, err := nvcdi.NewDeviceNamer(strategy)
@@ -313,6 +348,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
313348
nvcdi.WithLibrarySearchPaths(opts.librarySearchPaths),
314349
nvcdi.WithCSVFiles(opts.csv.files),
315350
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns),
351+
nvcdi.WithFeatureFlags(opts.featureFlags...),
316352
// We set the following to allow for dependency injection:
317353
nvcdi.WithNvmlLib(opts.nvmllib),
318354
}
@@ -326,7 +362,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
326362
return nil, fmt.Errorf("failed to create CDI library: %v", err)
327363
}
328364

329-
deviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
365+
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
330366
if err != nil {
331367
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
332368
}
@@ -336,16 +372,79 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
336372
return nil, fmt.Errorf("failed to create edits common for entities: %v", err)
337373
}
338374

339-
return spec.New(
375+
commonSpecOptions := []spec.Option{
340376
spec.WithVendor(opts.vendor),
341-
spec.WithClass(opts.class),
342-
spec.WithDeviceSpecs(deviceSpecs),
343377
spec.WithEdits(*commonEdits.ContainerEdits),
344378
spec.WithFormat(opts.format),
345379
spec.WithMergedDeviceOptions(
346380
transform.WithName(allDeviceName),
347381
transform.WithSkipIfExists(true),
348382
),
349383
spec.WithPermissions(0644),
384+
}
385+
386+
fullSpec, err := spec.New(
387+
append(commonSpecOptions,
388+
spec.WithClass(opts.class),
389+
spec.WithDeviceSpecs(allDeviceSpecs),
390+
)...,
350391
)
392+
if err != nil {
393+
return nil, err
394+
}
395+
var allSpecs []generatedSpecs
396+
397+
allSpecs = append(allSpecs, generatedSpecs{Interface: fullSpec, filenameInfix: ""})
398+
399+
deviceSpecsByDeviceCoherence := (deviceSpecs)(allDeviceSpecs).splitOnAnnotation("gpu.nvidia.com/coherent")
400+
401+
if coherentDeviceSpecs := deviceSpecsByDeviceCoherence["gpu.nvidia.com/coherent=true"]; len(coherentDeviceSpecs) > 0 {
402+
infix := ".coherent"
403+
coherentSpecs, err := spec.New(
404+
append(commonSpecOptions,
405+
spec.WithClass(opts.class+infix),
406+
spec.WithDeviceSpecs(coherentDeviceSpecs),
407+
)...,
408+
)
409+
if err != nil {
410+
return nil, err
411+
}
412+
allSpecs = append(allSpecs, generatedSpecs{Interface: coherentSpecs, filenameInfix: infix})
413+
}
414+
415+
if noncoherentDeviceSpecs := deviceSpecsByDeviceCoherence["gpu.nvidia.com/coherent=false"]; len(noncoherentDeviceSpecs) > 0 {
416+
infix := ".noncoherent"
417+
noncoherentSpecs, err := spec.New(
418+
append(commonSpecOptions,
419+
spec.WithClass(opts.class+infix),
420+
spec.WithDeviceSpecs(noncoherentDeviceSpecs),
421+
)...,
422+
)
423+
424+
if err != nil {
425+
return nil, err
426+
}
427+
allSpecs = append(allSpecs, generatedSpecs{Interface: noncoherentSpecs, filenameInfix: infix})
428+
}
429+
430+
return allSpecs, nil
431+
}
432+
433+
type deviceSpecs []specs.Device
434+
435+
func (d deviceSpecs) splitOnAnnotation(key string) map[string][]specs.Device {
436+
splitSpecs := make(map[string][]specs.Device)
437+
438+
for _, deviceSpec := range d {
439+
if len(deviceSpec.Annotations) == 0 {
440+
continue
441+
}
442+
value, ok := deviceSpec.Annotations[key]
443+
if !ok {
444+
continue
445+
}
446+
splitSpecs[key+"="+value] = append(splitSpecs[key+"="+value], deviceSpec)
447+
}
448+
449+
return splitSpecs
351450
}

cmd/nvidia-ctk/cdi/generate/generate_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/stretchr/testify/require"
2929

3030
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
31+
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3132
)
3233

3334
func TestGenerateSpec(t *testing.T) {
@@ -391,12 +392,15 @@ containerEdits:
391392
}
392393
tc.options.nvmllib = server
393394

394-
spec, err := c.generateSpec(&tc.options)
395+
tc.options.featureFlags = []string{string(nvcdi.FeatureDisableCoherentAnnotations)}
396+
specs, err := c.generateSpecs(&tc.options)
395397
require.ErrorIs(t, err, tc.expectedError)
396398

397399
var buf bytes.Buffer
398-
_, err = spec.WriteTo(&buf)
399-
require.NoError(t, err)
400+
for _, spec := range specs {
401+
_, err = spec.WriteTo(&buf)
402+
require.NoError(t, err)
403+
}
400404

401405
require.Equal(t, strings.ReplaceAll(tc.expectedSpec, "{{ .driverRoot }}", driverRoot), buf.String())
402406
})

pkg/nvcdi/api.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,7 @@ const (
8080
// FeatureDisableNvsandboxUtils disables the use of nvsandboxutils when
8181
// querying devices.
8282
FeatureDisableNvsandboxUtils = FeatureFlag("disable-nvsandbox-utils")
83+
// FeatureDisableCoherentAnnotations disables the addition of annotations
84+
// coherent or non-coherent devices.
85+
FeatureDisableCoherentAnnotations = FeatureFlag("disable-coherent-annotations")
8386
)

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ type fullGPUDeviceSpecGenerator struct {
3636
*nvmllib
3737
uuid string
3838
index int
39+
40+
featureFlags map[FeatureFlag]bool
3941
}
4042

4143
var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
@@ -44,7 +46,7 @@ func (l *fullGPUDeviceSpecGenerator) GetUUID() (string, error) {
4446
return l.uuid, nil
4547
}
4648

47-
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromDevice(index int, d device.Device) (*fullGPUDeviceSpecGenerator, error) {
49+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromDevice(index int, d device.Device, featureFlags map[FeatureFlag]bool) (*fullGPUDeviceSpecGenerator, error) {
4850
uuid, ret := d.GetUUID()
4951
if ret != nvml.SUCCESS {
5052
return nil, fmt.Errorf("failed to get device UUID: %v", ret)
@@ -53,12 +55,14 @@ func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromDevice(index int, d device.De
5355
nvmllib: l,
5456
uuid: uuid,
5557
index: index,
58+
59+
featureFlags: featureFlags,
5660
}
5761

5862
return e, nil
5963
}
6064

61-
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(uuid string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
65+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(uuid string, nvmlDevice nvml.Device, featureFlags map[FeatureFlag]bool) (DeviceSpecGenerator, error) {
6266
index, ret := nvmlDevice.GetIndex()
6367
if ret != nvml.SUCCESS {
6468
return nil, fmt.Errorf("failed to get device index: %v", ret)
@@ -68,6 +72,8 @@ func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(uuid string, nvmlD
6872
nvmllib: l,
6973
uuid: uuid,
7074
index: index,
75+
76+
featureFlags: featureFlags,
7177
}
7278
return e, nil
7379
}
@@ -83,11 +89,17 @@ func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
8389
return nil, fmt.Errorf("failed to get device names: %w", err)
8490
}
8591

92+
annotations, err := l.getDeviceAnnotations()
93+
if err != nil {
94+
l.logger.Warning("Ignoring error getting device annotations for device(s) %v: %v", names, err)
95+
annotations = nil
96+
}
8697
var deviceSpecs []specs.Device
8798
for _, name := range names {
8899
deviceSpec := specs.Device{
89100
Name: name,
90101
ContainerEdits: *deviceEdits.ContainerEdits,
102+
Annotations: annotations,
91103
}
92104
deviceSpecs = append(deviceSpecs, deviceSpec)
93105
}
@@ -99,6 +111,29 @@ func (l *fullGPUDeviceSpecGenerator) device() (device.Device, error) {
99111
return l.devicelib.NewDeviceByUUID(l.uuid)
100112
}
101113

114+
func (l *fullGPUDeviceSpecGenerator) getDeviceAnnotations() (map[string]string, error) {
115+
if l.featureFlags[FeatureDisableCoherentAnnotations] {
116+
return nil, nil
117+
}
118+
119+
device, err := l.device()
120+
if err != nil {
121+
return nil, err
122+
}
123+
124+
// TODO: Should we distinguish between not-supported and disabled?
125+
isCoherent, err := device.IsCoherent()
126+
if err != nil {
127+
return nil, fmt.Errorf("failed to check device coherence: %w", err)
128+
}
129+
130+
annotations := map[string]string{
131+
"gpu.nvidia.com/coherent": fmt.Sprintf("%v", isCoherent),
132+
}
133+
134+
return annotations, nil
135+
}
136+
102137
// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
103138
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
104139
device, err := l.device()

pkg/nvcdi/lib-nvml.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (l *nvmllib) newDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvm
102102
return l.newMIGDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice)
103103
}
104104

105-
return l.newFullGPUDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice)
105+
return l.newFullGPUDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice, l.featureFlags)
106106
}
107107

108108
// getDeviceSpecGeneratorsForAllDevices returns the CDI device spec generators
@@ -118,7 +118,7 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
118118
if isMigEnabled {
119119
return nil
120120
}
121-
fullGPU, err := l.newFullGPUDeviceSpecGeneratorFromDevice(i, d)
121+
fullGPU, err := l.newFullGPUDeviceSpecGeneratorFromDevice(i, d, l.featureFlags)
122122
if err != nil {
123123
return err
124124
}

pkg/nvcdi/mig-device-nvml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (l *migDeviceSpecGenerator) GetUUID() (string, error) {
4141
}
4242

4343
func (l *nvmllib) newMIGDeviceSpecGeneratorFromDevice(i int, d device.Device, j int, m device.MigDevice) (*migDeviceSpecGenerator, error) {
44-
parent, err := l.newFullGPUDeviceSpecGeneratorFromDevice(i, d)
44+
parent, err := l.newFullGPUDeviceSpecGeneratorFromDevice(i, d, map[FeatureFlag]bool{FeatureDisableCoherentAnnotations: true})
4545
if err != nil {
4646
return nil, err
4747
}

pkg/nvcdi/options.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,22 @@ func WithDisabledHook[T string | HookName](hook T) Option {
165165
}
166166
}
167167

168-
// WithFeatureFlag allows specified features to be toggled on.
169-
// This option can be specified multiple times for each feature flag.
170-
func WithFeatureFlag(featureFlag FeatureFlag) Option {
168+
// WithFeatureFlags allows the specified set of features to be toggled on.
169+
func WithFeatureFlags[T string | FeatureFlag](featureFlags ...T) Option {
171170
return func(o *nvcdilib) {
172171
if o.featureFlags == nil {
173172
o.featureFlags = make(map[FeatureFlag]bool)
174173
}
175-
o.featureFlags[featureFlag] = true
174+
for _, featureFlag := range featureFlags {
175+
o.featureFlags[FeatureFlag(featureFlag)] = true
176+
}
176177
}
177178
}
179+
180+
// WithFeatureFlag allows specified features to be toggled on.
181+
// This option can be specified multiple times for each feature flag.
182+
//
183+
// Deprecated: Use WithFeatureFlags
184+
func WithFeatureFlag[T string | FeatureFlag](featureFlag T) Option {
185+
return WithFeatureFlags(featureFlag)
186+
}

0 commit comments

Comments
 (0)