@@ -18,6 +18,7 @@ package generate
1818
1919import (
2020 "context"
21+ "errors"
2122 "fmt"
2223 "os"
2324 "path/filepath"
@@ -26,6 +27,7 @@ import (
2627 "github.com/urfave/cli/v3"
2728
2829 cdi "tags.cncf.io/container-device-interface/pkg/parser"
30+ "tags.cncf.io/container-device-interface/specs-go"
2931
3032 "github.com/NVIDIA/go-nvml/pkg/nvml"
3133
@@ -64,6 +66,8 @@ type options struct {
6466 disabledHooks []string
6567 enabledHooks []string
6668
69+ featureFlags []string
70+
6771 csv struct {
6872 files []string
6973 ignorePatterns []string
@@ -222,6 +226,13 @@ func (m command) build() *cli.Command {
222226 Destination : & opts .enabledHooks ,
223227 Sources : cli .EnvVars ("NVIDIA_CTK_CDI_GENERATE_ENABLED_HOOKS" ),
224228 },
229+ & cli.StringSliceFlag {
230+ Name : "feature-flag" ,
231+ Aliases : []string {"feature-flags" },
232+ Usage : "specify feature flags for CDI spec generation" ,
233+ Destination : & opts .featureFlags ,
234+ Sources : cli .EnvVars ("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS" ),
235+ },
225236 },
226237 }
227238
@@ -276,21 +287,18 @@ func (m command) validateFlags(c *cli.Command, opts *options) error {
276287}
277288
278289func (m command ) run (opts * options ) error {
279- spec , err := m .generateSpec (opts )
290+ specs , err := m .generateSpecs (opts )
280291 if err != nil {
281292 return fmt .Errorf ("failed to generate CDI spec: %v" , err )
282293 }
283- m .logger .Infof ("Generated CDI spec with version %v" , spec .Raw ().Version )
284294
285- if opts .output == "" {
286- _ , err := spec .WriteTo (os .Stdout )
287- if err != nil {
288- return fmt .Errorf ("failed to write CDI spec to STDOUT: %v" , err )
289- }
290- return nil
291- }
295+ var errs error
296+ for _ , spec := range specs {
297+ m .logger .Infof ("Generated CDI spec with version %v" , spec .Raw ().Version )
292298
293- return spec .Save (opts .output )
299+ errs = errors .Join (errs , spec .Save (opts .output ))
300+ }
301+ return errs
294302}
295303
296304func formatFromFilename (filename string ) string {
@@ -305,7 +313,34 @@ func formatFromFilename(filename string) string {
305313 return ""
306314}
307315
308- func (m command ) generateSpec (opts * options ) (spec.Interface , error ) {
316+ type generatedSpecs struct {
317+ spec.Interface
318+ filenameInfix string
319+ }
320+
321+ func (g * generatedSpecs ) Save (filename string ) error {
322+ filename = g .updateFilename (filename )
323+
324+ if filename == "" {
325+ _ , err := g .WriteTo (os .Stdout )
326+ if err != nil {
327+ return fmt .Errorf ("failed to write CDI spec to STDOUT: %v" , err )
328+ }
329+ return nil
330+ }
331+
332+ return g .Interface .Save (filename )
333+ }
334+
335+ func (g generatedSpecs ) updateFilename (filename string ) string {
336+ if g .filenameInfix == "" || filename == "" {
337+ return filename
338+ }
339+ ext := filepath .Ext (filepath .Base (filename ))
340+ return strings .TrimSuffix (filename , ext ) + g .filenameInfix + ext
341+ }
342+
343+ func (m command ) generateSpecs (opts * options ) ([]generatedSpecs , error ) {
309344 var deviceNamers []nvcdi.DeviceNamer
310345 for _ , strategy := range opts .deviceNameStrategies {
311346 deviceNamer , err := nvcdi .NewDeviceNamer (strategy )
@@ -329,6 +364,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
329364 nvcdi .WithCSVIgnorePatterns (opts .csv .ignorePatterns ),
330365 nvcdi .WithDisabledHooks (opts .disabledHooks ... ),
331366 nvcdi .WithEnabledHooks (opts .enabledHooks ... ),
367+ nvcdi .WithFeatureFlags (opts .featureFlags ... ),
332368 // We set the following to allow for dependency injection:
333369 nvcdi .WithNvmlLib (opts .nvmllib ),
334370 }
@@ -338,7 +374,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
338374 return nil , fmt .Errorf ("failed to create CDI library: %v" , err )
339375 }
340376
341- deviceSpecs , err := cdilib .GetDeviceSpecsByID ("all" )
377+ allDeviceSpecs , err := cdilib .GetDeviceSpecsByID ("all" )
342378 if err != nil {
343379 return nil , fmt .Errorf ("failed to create device CDI specs: %v" , err )
344380 }
@@ -348,16 +384,79 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
348384 return nil , fmt .Errorf ("failed to create edits common for entities: %v" , err )
349385 }
350386
351- return spec .New (
387+ commonSpecOptions := [] spec.Option {
352388 spec .WithVendor (opts .vendor ),
353- spec .WithClass (opts .class ),
354- spec .WithDeviceSpecs (deviceSpecs ),
355389 spec .WithEdits (* commonEdits .ContainerEdits ),
356390 spec .WithFormat (opts .format ),
357391 spec .WithMergedDeviceOptions (
358392 transform .WithName (allDeviceName ),
359393 transform .WithSkipIfExists (true ),
360394 ),
361395 spec .WithPermissions (0644 ),
396+ }
397+
398+ fullSpec , err := spec .New (
399+ append (commonSpecOptions ,
400+ spec .WithClass (opts .class ),
401+ spec .WithDeviceSpecs (allDeviceSpecs ),
402+ )... ,
362403 )
404+ if err != nil {
405+ return nil , err
406+ }
407+ var allSpecs []generatedSpecs
408+
409+ allSpecs = append (allSpecs , generatedSpecs {Interface : fullSpec , filenameInfix : "" })
410+
411+ deviceSpecsByDeviceCoherence := (deviceSpecs )(allDeviceSpecs ).splitOnAnnotation ("gpu.nvidia.com/coherent" )
412+
413+ if coherentDeviceSpecs := deviceSpecsByDeviceCoherence ["gpu.nvidia.com/coherent=true" ]; len (coherentDeviceSpecs ) > 0 {
414+ infix := ".coherent"
415+ coherentSpecs , err := spec .New (
416+ append (commonSpecOptions ,
417+ spec .WithClass (opts .class + infix ),
418+ spec .WithDeviceSpecs (coherentDeviceSpecs ),
419+ )... ,
420+ )
421+ if err != nil {
422+ return nil , err
423+ }
424+ allSpecs = append (allSpecs , generatedSpecs {Interface : coherentSpecs , filenameInfix : infix })
425+ }
426+
427+ if noncoherentDeviceSpecs := deviceSpecsByDeviceCoherence ["gpu.nvidia.com/coherent=false" ]; len (noncoherentDeviceSpecs ) > 0 {
428+ infix := ".noncoherent"
429+ noncoherentSpecs , err := spec .New (
430+ append (commonSpecOptions ,
431+ spec .WithClass (opts .class + infix ),
432+ spec .WithDeviceSpecs (noncoherentDeviceSpecs ),
433+ )... ,
434+ )
435+
436+ if err != nil {
437+ return nil , err
438+ }
439+ allSpecs = append (allSpecs , generatedSpecs {Interface : noncoherentSpecs , filenameInfix : infix })
440+ }
441+
442+ return allSpecs , nil
443+ }
444+
445+ type deviceSpecs []specs.Device
446+
447+ func (d deviceSpecs ) splitOnAnnotation (key string ) map [string ][]specs.Device {
448+ splitSpecs := make (map [string ][]specs.Device )
449+
450+ for _ , deviceSpec := range d {
451+ if len (deviceSpec .Annotations ) == 0 {
452+ continue
453+ }
454+ value , ok := deviceSpec .Annotations [key ]
455+ if ! ok {
456+ continue
457+ }
458+ splitSpecs [key + "=" + value ] = append (splitSpecs [key + "=" + value ], deviceSpec )
459+ }
460+
461+ return splitSpecs
363462}
0 commit comments