@@ -18,6 +18,7 @@ package generate
1818
1919import (
2020 "context"
21+ "errors"
2122 "fmt"
2223 "os"
2324 "path/filepath"
@@ -26,6 +27,7 @@ import (
2627 "github.com/urfave/cli/v3"
2728
2829 cdi "tags.cncf.io/container-device-interface/pkg/parser"
30+ "tags.cncf.io/container-device-interface/specs-go"
2931
3032 "github.com/NVIDIA/go-nvml/pkg/nvml"
3133
@@ -63,6 +65,8 @@ type options struct {
6365 librarySearchPaths []string
6466 disabledHooks []string
6567
68+ featureFlags []string
69+
6670 csv struct {
6771 files []string
6872 ignorePatterns []string
@@ -214,6 +218,13 @@ func (m command) build() *cli.Command {
214218 Destination : & opts .disabledHooks ,
215219 Sources : cli .EnvVars ("NVIDIA_CTK_CDI_GENERATE_DISABLED_HOOKS" ),
216220 },
221+ & cli.StringSliceFlag {
222+ Name : "feature-flag" ,
223+ Aliases : []string {"feature-flags" },
224+ Usage : "specify feature flags for CDI spec generation" ,
225+ Destination : & opts .featureFlags ,
226+ Sources : cli .EnvVars ("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS" ),
227+ },
217228 },
218229 }
219230
@@ -262,21 +273,18 @@ func (m command) validateFlags(c *cli.Command, opts *options) error {
262273}
263274
264275func (m command ) run (opts * options ) error {
265- spec , err := m .generateSpec (opts )
276+ specs , err := m .generateSpecs (opts )
266277 if err != nil {
267278 return fmt .Errorf ("failed to generate CDI spec: %v" , err )
268279 }
269- m .logger .Infof ("Generated CDI spec with version %v" , spec .Raw ().Version )
270280
271- if opts .output == "" {
272- _ , err := spec .WriteTo (os .Stdout )
273- if err != nil {
274- return fmt .Errorf ("failed to write CDI spec to STDOUT: %v" , err )
275- }
276- return nil
277- }
281+ var errs error
282+ for _ , spec := range specs {
283+ m .logger .Infof ("Generated CDI spec with version %v" , spec .Raw ().Version )
278284
279- return spec .Save (opts .output )
285+ errs = errors .Join (errs , spec .Save (opts .output ))
286+ }
287+ return errs
280288}
281289
282290func formatFromFilename (filename string ) string {
@@ -291,7 +299,34 @@ func formatFromFilename(filename string) string {
291299 return ""
292300}
293301
294- func (m command ) generateSpec (opts * options ) (spec.Interface , error ) {
302+ type generatedSpecs struct {
303+ spec.Interface
304+ filenameInfix string
305+ }
306+
307+ func (g * generatedSpecs ) Save (filename string ) error {
308+ filename = g .updateFilename (filename )
309+
310+ if filename == "" {
311+ _ , err := g .WriteTo (os .Stdout )
312+ if err != nil {
313+ return fmt .Errorf ("failed to write CDI spec to STDOUT: %v" , err )
314+ }
315+ return nil
316+ }
317+
318+ return g .Interface .Save (filename )
319+ }
320+
321+ func (g generatedSpecs ) updateFilename (filename string ) string {
322+ if g .filenameInfix == "" || filename == "" {
323+ return filename
324+ }
325+ ext := filepath .Ext (filepath .Base (filename ))
326+ return strings .TrimSuffix (filename , ext ) + g .filenameInfix + ext
327+ }
328+
329+ func (m command ) generateSpecs (opts * options ) ([]generatedSpecs , error ) {
295330 var deviceNamers []nvcdi.DeviceNamer
296331 for _ , strategy := range opts .deviceNameStrategies {
297332 deviceNamer , err := nvcdi .NewDeviceNamer (strategy )
@@ -313,6 +348,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
313348 nvcdi .WithLibrarySearchPaths (opts .librarySearchPaths ),
314349 nvcdi .WithCSVFiles (opts .csv .files ),
315350 nvcdi .WithCSVIgnorePatterns (opts .csv .ignorePatterns ),
351+ nvcdi .WithFeatureFlags (opts .featureFlags ... ),
316352 // We set the following to allow for dependency injection:
317353 nvcdi .WithNvmlLib (opts .nvmllib ),
318354 }
@@ -326,7 +362,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
326362 return nil , fmt .Errorf ("failed to create CDI library: %v" , err )
327363 }
328364
329- deviceSpecs , err := cdilib .GetDeviceSpecsByID ("all" )
365+ allDeviceSpecs , err := cdilib .GetDeviceSpecsByID ("all" )
330366 if err != nil {
331367 return nil , fmt .Errorf ("failed to create device CDI specs: %v" , err )
332368 }
@@ -336,16 +372,79 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
336372 return nil , fmt .Errorf ("failed to create edits common for entities: %v" , err )
337373 }
338374
339- return spec .New (
375+ commonSpecOptions := [] spec.Option {
340376 spec .WithVendor (opts .vendor ),
341- spec .WithClass (opts .class ),
342- spec .WithDeviceSpecs (deviceSpecs ),
343377 spec .WithEdits (* commonEdits .ContainerEdits ),
344378 spec .WithFormat (opts .format ),
345379 spec .WithMergedDeviceOptions (
346380 transform .WithName (allDeviceName ),
347381 transform .WithSkipIfExists (true ),
348382 ),
349383 spec .WithPermissions (0644 ),
384+ }
385+
386+ fullSpec , err := spec .New (
387+ append (commonSpecOptions ,
388+ spec .WithClass (opts .class ),
389+ spec .WithDeviceSpecs (allDeviceSpecs ),
390+ )... ,
350391 )
392+ if err != nil {
393+ return nil , err
394+ }
395+ var allSpecs []generatedSpecs
396+
397+ allSpecs = append (allSpecs , generatedSpecs {Interface : fullSpec , filenameInfix : "" })
398+
399+ deviceSpecsByDeviceCoherence := (deviceSpecs )(allDeviceSpecs ).splitOnAnnotation ("gpu.nvidia.com/coherent" )
400+
401+ if coherentDeviceSpecs := deviceSpecsByDeviceCoherence ["gpu.nvidia.com/coherent=true" ]; len (coherentDeviceSpecs ) > 0 {
402+ infix := ".coherent"
403+ coherentSpecs , err := spec .New (
404+ append (commonSpecOptions ,
405+ spec .WithClass (opts .class + infix ),
406+ spec .WithDeviceSpecs (coherentDeviceSpecs ),
407+ )... ,
408+ )
409+ if err != nil {
410+ return nil , err
411+ }
412+ allSpecs = append (allSpecs , generatedSpecs {Interface : coherentSpecs , filenameInfix : infix })
413+ }
414+
415+ if noncoherentDeviceSpecs := deviceSpecsByDeviceCoherence ["gpu.nvidia.com/coherent=false" ]; len (noncoherentDeviceSpecs ) > 0 {
416+ infix := ".noncoherent"
417+ noncoherentSpecs , err := spec .New (
418+ append (commonSpecOptions ,
419+ spec .WithClass (opts .class + infix ),
420+ spec .WithDeviceSpecs (noncoherentDeviceSpecs ),
421+ )... ,
422+ )
423+
424+ if err != nil {
425+ return nil , err
426+ }
427+ allSpecs = append (allSpecs , generatedSpecs {Interface : noncoherentSpecs , filenameInfix : infix })
428+ }
429+
430+ return allSpecs , nil
431+ }
432+
433+ type deviceSpecs []specs.Device
434+
435+ func (d deviceSpecs ) splitOnAnnotation (key string ) map [string ][]specs.Device {
436+ splitSpecs := make (map [string ][]specs.Device )
437+
438+ for _ , deviceSpec := range d {
439+ if len (deviceSpec .Annotations ) == 0 {
440+ continue
441+ }
442+ value , ok := deviceSpec .Annotations [key ]
443+ if ! ok {
444+ continue
445+ }
446+ splitSpecs [key + "=" + value ] = append (splitSpecs [key + "=" + value ], deviceSpec )
447+ }
448+
449+ return splitSpecs
351450}
0 commit comments