@@ -62,7 +62,9 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
6262 return nil , fmt .Errorf ("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices" )
6363 }
6464 if len (automaticDevices ) > 0 {
65- automaticDevices = append (automaticDevices , gatedDevices (image ).DeviceRequests ()... )
65+ automaticDevices = append (automaticDevices , withUniqueDevices (gatedDevices (image )).DeviceRequests ()... )
66+ automaticDevices = append (automaticDevices , withUniqueDevices (imexDevices (image )).DeviceRequests ()... )
67+
6668 automaticModifier , err := newAutomaticCDISpecModifier (logger , cfg , automaticDevices )
6769 if err == nil {
6870 return automaticModifier , nil
@@ -135,6 +137,17 @@ func (g gatedDevices) DeviceRequests() []string {
135137 return devices
136138}
137139
140+ type imexDevices image.CUDA
141+
142+ func (d imexDevices ) DeviceRequests () []string {
143+ var devices []string
144+ i := (image .CUDA )(d )
145+ for _ , channelID := range i .ImexChannelRequests () {
146+ devices = append (devices , "mode=imex,id=" + channelID )
147+ }
148+ return devices
149+ }
150+
138151// filterAutomaticDevices searches for "automatic" device names in the input slice.
139152// "Automatic" devices are a well-defined list of CDI device names which, when requested,
140153// trigger the generation of a CDI spec at runtime. This removes the need to generate a
@@ -155,17 +168,21 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
155168
156169 perModeIdentifiers := make (map [string ][]string )
157170 perModeDeviceClass := map [string ]string {"auto" : automaticDeviceClass }
158- modes := []string {"auto" }
171+ uniqueModes := []string {"auto" }
172+ seen := make (map [string ]bool )
159173 for _ , device := range devices {
160- if strings .HasPrefix (device , "mode=" ) {
161- modes = append (modes , strings .TrimPrefix (device , "mode=" ))
162- continue
174+ mode , id := getModeIdentifier (device )
175+ if ! seen [mode ] {
176+ uniqueModes = append (uniqueModes , mode )
177+ seen [mode ] = true
178+ }
179+ if id != "" {
180+ perModeIdentifiers [id ] = append (perModeIdentifiers [id ], id )
163181 }
164- perModeIdentifiers ["auto" ] = append (perModeIdentifiers ["auto" ], strings .TrimPrefix (device , automaticDevicePrefix ))
165182 }
166183
167184 var modifiers oci.SpecModifiers
168- for _ , mode := range modes {
185+ for _ , mode := range uniqueModes {
169186 cdilib , err := nvcdi .New (
170187 nvcdi .WithLogger (logger ),
171188 nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
@@ -197,6 +214,18 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
197214 return modifiers , nil
198215}
199216
217+ func getModeIdentifier (device string ) (string , string ) {
218+ if ! strings .HasPrefix (device , "mode=" ) {
219+ return "auto" , strings .TrimPrefix (device , automaticDevicePrefix )
220+ }
221+ parts := strings .SplitN (device , "," , 2 )
222+ mode := strings .TrimPrefix (parts [0 ], "mode=" )
223+ if len (parts ) == 2 {
224+ return mode , strings .TrimPrefix (parts [1 ], "id=" )
225+ }
226+ return mode , ""
227+ }
228+
200229type deduplicatedDeviceRequestor struct {
201230 deviceRequestor
202231}
0 commit comments