@@ -62,6 +62,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
6262 return nil , fmt .Errorf ("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices" )
6363 }
6464 if len (automaticDevices ) > 0 {
65+ automaticDevices = append (automaticDevices , gatedDevices (image ).DeviceRequests ()... )
6566 automaticModifier , err := newAutomaticCDISpecModifier (logger , cfg , automaticDevices )
6667 if err == nil {
6768 return automaticModifier , nil
@@ -111,6 +112,29 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string {
111112 return devices
112113}
113114
115+ type gatedDevices image.CUDA
116+
117+ // DeviceRequests returns a list of devices that are required for gated devices.
118+ func (g gatedDevices ) DeviceRequests () []string {
119+ i := (image .CUDA )(g )
120+
121+ var devices []string
122+ if i .Getenv ("NVIDIA_GDS" ) == "enabled" {
123+ devices = append (devices , "mode=gds" )
124+ }
125+ if i .Getenv ("NVIDIA_MOFED" ) == "enabled" {
126+ devices = append (devices , "mode=mofed" )
127+ }
128+ if i .Getenv ("NVIDIA_GDRCOPY" ) == "enabled" {
129+ devices = append (devices , "mode=gdrcopy" )
130+ }
131+ if i .Getenv ("NVIDIA_NVSWITCH" ) == "enabled" {
132+ devices = append (devices , "mode=nvswitch" )
133+ }
134+
135+ return devices
136+ }
137+
114138// filterAutomaticDevices searches for "automatic" device names in the input slice.
115139// "Automatic" devices are a well-defined list of CDI device names which, when requested,
116140// trigger the generation of a CDI spec at runtime. This removes the need to generate a
@@ -129,35 +153,48 @@ func filterAutomaticDevices(devices []string) []string {
129153func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
130154 logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
131155
132- var identifiers []string
156+ perModeIdentifiers := make (map [string ][]string )
157+ perModeDeviceClass := map [string ]string {"auto" : automaticDeviceClass }
158+ modes := []string {"auto" }
133159 for _ , device := range devices {
134- identifiers = append (identifiers , strings .TrimPrefix (device , automaticDevicePrefix ))
160+ if strings .HasPrefix (device , "mode=" ) {
161+ modes = append (modes , strings .TrimPrefix (device , "mode=" ))
162+ continue
163+ }
164+ perModeIdentifiers ["auto" ] = append (perModeIdentifiers ["auto" ], strings .TrimPrefix (device , automaticDevicePrefix ))
135165 }
136166
137- cdilib , err := nvcdi .New (
138- nvcdi .WithLogger (logger ),
139- nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
140- nvcdi .WithDriverRoot (cfg .NVIDIAContainerCLIConfig .Root ),
141- nvcdi .WithVendor (automaticDeviceVendor ),
142- nvcdi .WithClass (automaticDeviceClass ),
143- )
144- if err != nil {
145- return nil , fmt .Errorf ("failed to construct CDI library: %w" , err )
146- }
167+ var modifiers oci.SpecModifiers
168+ for _ , mode := range modes {
169+ cdilib , err := nvcdi .New (
170+ nvcdi .WithLogger (logger ),
171+ nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
172+ nvcdi .WithDriverRoot (cfg .NVIDIAContainerCLIConfig .Root ),
173+ nvcdi .WithVendor (automaticDeviceVendor ),
174+ nvcdi .WithClass (perModeDeviceClass [mode ]),
175+ nvcdi .WithMode (mode ),
176+ )
177+ if err != nil {
178+ return nil , fmt .Errorf ("failed to construct CDI library for mode %q: %w" , mode , err )
179+ }
147180
148- spec , err := cdilib .GetSpec (identifiers ... )
149- if err != nil {
150- return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
151- }
152- cdiDeviceRequestor , err := cdi .New (
153- cdi .WithLogger (logger ),
154- cdi .WithSpec (spec .Raw ()),
155- )
156- if err != nil {
157- return nil , fmt .Errorf ("failed to construct CDI modifier: %w" , err )
181+ spec , err := cdilib .GetSpec (perModeIdentifiers [mode ]... )
182+ if err != nil {
183+ return nil , fmt .Errorf ("failed to generate CDI spec for mode %q: %w" , mode , err )
184+ }
185+
186+ cdiDeviceRequestor , err := cdi .New (
187+ cdi .WithLogger (logger ),
188+ cdi .WithSpec (spec .Raw ()),
189+ )
190+ if err != nil {
191+ return nil , fmt .Errorf ("failed to construct CDI modifier for mode %q: %w" , mode , err )
192+ }
193+
194+ modifiers = append (modifiers , cdiDeviceRequestor )
158195 }
159196
160- return cdiDeviceRequestor , nil
197+ return modifiers , nil
161198}
162199
163200type deduplicatedDeviceRequestor struct {
0 commit comments