@@ -18,20 +18,49 @@ package nvcdi
1818
1919import (
2020 "fmt"
21+ "slices"
22+ "strconv"
23+ "strings"
2124
2225 "tags.cncf.io/container-device-interface/pkg/cdi"
2326 "tags.cncf.io/container-device-interface/specs-go"
2427
28+ "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
29+ "github.com/NVIDIA/go-nvml/pkg/nvml"
30+ "github.com/google/uuid"
31+
2532 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2633 "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2734 "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
2835)
2936
3037type csvlib nvcdilib
3138
39+ type mixedcsvlib nvcdilib
40+
3241var _ deviceSpecGeneratorFactory = (* csvlib )(nil )
3342
43+ // DeviceSpecGenerators creates a set of generators for the specified set of
44+ // devices.
45+ // If NVML is not available or the disable-multiple-csv-devices feature flag is
46+ // enabled, a single device is assumed.
3447func (l * csvlib ) DeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
48+ if l .featureFlags [FeatureDisableMultipleCSVDevices ] {
49+ return l .purecsvDeviceSpecGenerators (ids ... )
50+ }
51+ hasNVML , _ := l .infolib .HasNvml ()
52+ if ! hasNVML {
53+ return l .purecsvDeviceSpecGenerators (ids ... )
54+ }
55+ mixed , err := l .mixedDeviceSpecGenerators (ids ... )
56+ if err != nil {
57+ l .logger .Warningf ("Failed to create mixed CSV spec generator; falling back to pure CSV implementation: %v" , err )
58+ return l .purecsvDeviceSpecGenerators (ids ... )
59+ }
60+ return mixed , nil
61+ }
62+
63+ func (l * csvlib ) purecsvDeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
3564 for _ , id := range ids {
3665 switch id {
3766 case "all" :
@@ -40,12 +69,41 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
4069 return nil , fmt .Errorf ("unsupported device id: %v" , id )
4170 }
4271 }
72+ g := & csvDeviceGenerator {
73+ csvlib : l ,
74+ index : 0 ,
75+ uuid : "" ,
76+ }
77+ return g , nil
78+ }
79+
80+ func (l * csvlib ) mixedDeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
81+ return (* mixedcsvlib )(l ).DeviceSpecGenerators (ids ... )
82+ }
4383
44- return l , nil
84+ // A csvDeviceGenerator generates CDI specs for a device based on a set of
85+ // platform-specific CSV files.
86+ type csvDeviceGenerator struct {
87+ * csvlib
88+ index int
89+ uuid string
90+ onlyDeviceNodes []string
91+ additionalDeviceNodes []string
92+ }
93+
94+ func (l * csvDeviceGenerator ) GetUUID () (string , error ) {
95+ return l .uuid , nil
4596}
4697
4798// GetDeviceSpecs returns the CDI device specs for a single device.
48- func (l * csvlib ) GetDeviceSpecs () ([]specs.Device , error ) {
99+ func (l * csvDeviceGenerator ) GetDeviceSpecs () ([]specs.Device , error ) {
100+ mountSpecs := tegra .MountSpecsFromCSVFiles (l .logger , l .csvFiles ... )
101+ if len (l .onlyDeviceNodes ) > 0 {
102+ mountSpecs = tegra .Merge (
103+ tegra .WithoutRegularDeviceNodes (mountSpecs ),
104+ tegra .DeviceNodes (l .onlyDeviceNodes ... ),
105+ )
106+ }
49107 d , err := tegra .New (
50108 tegra .WithLogger (l .logger ),
51109 tegra .WithDriverRoot (l .driverRoot ),
@@ -55,8 +113,13 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
55113 tegra .WithLibrarySearchPaths (l .librarySearchPaths ... ),
56114 tegra .WithMountSpecsByPath (
57115 tegra .Filter (
58- tegra .MountSpecsFromCSVFiles (l .logger , l .csvFiles ... ),
59- tegra .Symlinks (l .csvIgnorePatterns ... ),
116+ tegra .Merge (
117+ mountSpecs ,
118+ tegra .DeviceNodes (l .additionalDeviceNodes ... ),
119+ ),
120+ tegra .Merge (
121+ tegra .Symlinks (l .csvIgnorePatterns ... ),
122+ ),
60123 ),
61124 ),
62125 )
@@ -68,7 +131,7 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
68131 return nil , fmt .Errorf ("failed to create container edits for CSV files: %v" , err )
69132 }
70133
71- names , err := l .deviceNamers .GetDeviceNames (0 , uuidIgnored {} )
134+ names , err := l .deviceNamers .GetDeviceNames (l . index , l )
72135 if err != nil {
73136 return nil , fmt .Errorf ("failed to get device name: %v" , err )
74137 }
@@ -88,3 +151,145 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
88151func (l * csvlib ) GetCommonEdits () (* cdi.ContainerEdits , error ) {
89152 return edits .FromDiscoverer (discover.None {})
90153}
154+
155+ func (l * mixedcsvlib ) DeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
156+ asNvmlLib := (* nvmllib )(l )
157+ err := asNvmlLib .init ()
158+ if err != nil {
159+ return nil , fmt .Errorf ("failed to initialize nvml: %w" , err )
160+ }
161+ defer asNvmlLib .tryShutdown ()
162+
163+ if slices .Contains (ids , "all" ) {
164+ ids , err = l .getAllDeviceIndices ()
165+ if err != nil {
166+ return nil , fmt .Errorf ("failed to get device indices: %w" , err )
167+ }
168+ }
169+
170+ var DeviceSpecGenerators DeviceSpecGenerators
171+ for _ , id := range ids {
172+ generator , err := l .deviceSpecGeneratorForId (device .Identifier (id ))
173+ if err != nil {
174+ return nil , fmt .Errorf ("failed to create device spec generator for device %q: %w" , id , err )
175+ }
176+ DeviceSpecGenerators = append (DeviceSpecGenerators , generator )
177+ }
178+
179+ return DeviceSpecGenerators , nil
180+ }
181+
182+ func (l * mixedcsvlib ) getAllDeviceIndices () ([]string , error ) {
183+ numDevices , ret := l .nvmllib .DeviceGetCount ()
184+ if ret != nvml .SUCCESS {
185+ return nil , fmt .Errorf ("faled to get device count: %v" , ret )
186+ }
187+
188+ var allIndices []string
189+ for index := range numDevices {
190+ allIndices = append (allIndices , fmt .Sprintf ("%d" , index ))
191+ }
192+ return allIndices , nil
193+ }
194+
195+ func (l * mixedcsvlib ) deviceSpecGeneratorForId (id device.Identifier ) (DeviceSpecGenerator , error ) {
196+ switch {
197+ case id .IsGpuUUID (), isIntegratedGPUID (id ):
198+ uuid := string (id )
199+ device , ret := l .nvmllib .DeviceGetHandleByUUID (uuid )
200+ if ret != nvml .SUCCESS {
201+ return nil , fmt .Errorf ("failed to get device handle from UUID %q: %v" , uuid , ret )
202+ }
203+ index , ret := device .GetIndex ()
204+ if ret != nvml .SUCCESS {
205+ return nil , fmt .Errorf ("failed to get device index: %v" , ret )
206+ }
207+ return l .csvDeviceSpecGenerator (index , uuid , device )
208+ case id .IsGpuIndex ():
209+ index , err := strconv .Atoi (string (id ))
210+ if err != nil {
211+ return nil , fmt .Errorf ("failed to convert device index to an int: %w" , err )
212+ }
213+ device , ret := l .nvmllib .DeviceGetHandleByIndex (index )
214+ if ret != nvml .SUCCESS {
215+ return nil , fmt .Errorf ("failed to get device handle from index: %v" , ret )
216+ }
217+ uuid , ret := device .GetUUID ()
218+ if ret != nvml .SUCCESS {
219+ return nil , fmt .Errorf ("failed to get UUID: %v" , ret )
220+ }
221+ return l .csvDeviceSpecGenerator (index , uuid , device )
222+ case id .IsMigUUID ():
223+ fallthrough
224+ case id .IsMigIndex ():
225+ return nil , fmt .Errorf ("generating a CDI spec for MIG id %q is not supported in CSV mode" , id )
226+ }
227+ return nil , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
228+ }
229+
230+ func (l * mixedcsvlib ) csvDeviceSpecGenerator (index int , uuid string , device nvml.Device ) (DeviceSpecGenerator , error ) {
231+ var additionalDeviceNodes []string
232+ isIntegrated , err := isIntegratedGPU (device )
233+ if err != nil {
234+ return nil , fmt .Errorf ("is-integrated check failed for device (index=%v,uuid=%v)" , index , uuid )
235+ }
236+ if ! isIntegrated {
237+ additionalDeviceNodes = []string {
238+ "/dev/nvidia-uvm" ,
239+ "/dev/nvidia-uvm-tools" ,
240+ }
241+ }
242+ g := & csvDeviceGenerator {
243+ csvlib : (* csvlib )(l ),
244+ index : index ,
245+ uuid : uuid ,
246+ onlyDeviceNodes : []string {fmt .Sprintf ("/dev/nvidia%d" , index )},
247+ additionalDeviceNodes : additionalDeviceNodes ,
248+ }
249+ return g , nil
250+ }
251+
252+ func isIntegratedGPUID (id device.Identifier ) bool {
253+ _ , err := uuid .Parse (string (id ))
254+ return err == nil
255+ }
256+
257+ // isIntegratedGPU checks whether the specified device is an integrated GPU.
258+ // As a proxy we check the PCI Bus if for thes
259+ // TODO: This should be replaced by an explicit NVML call once available.
260+ func isIntegratedGPU (d nvml.Device ) (bool , error ) {
261+ pciInfo , ret := d .GetPciInfo ()
262+ if ret == nvml .ERROR_NOT_SUPPORTED {
263+ name , ret := d .GetName ()
264+ if ret != nvml .SUCCESS {
265+ return false , fmt .Errorf ("failed to get device name: %v" , ret )
266+ }
267+ return isIntegratedGPUName (name ), nil
268+ }
269+ if ret != nvml .SUCCESS {
270+ return false , fmt .Errorf ("failed to get PCI info: %v" , ret )
271+ }
272+
273+ if pciInfo .Domain != 0 {
274+ return false , nil
275+ }
276+ if pciInfo .Bus != 1 {
277+ return false , nil
278+ }
279+ return pciInfo .Device == 0 , nil
280+ }
281+
282+ // isIntegratedGPUName returns true if the specified device name is associated
283+ // with a known iGPU.
284+ //
285+ // TODO: Consider making go-nvlib/pkg/nvlib/info/isIntegratedGPUName public
286+ // instead.
287+ func isIntegratedGPUName (name string ) bool {
288+ if strings .Contains (name , "(nvgpu)" ) {
289+ return true
290+ }
291+ if strings .Contains (name , "NVIDIA Thor" ) {
292+ return true
293+ }
294+ return false
295+ }
0 commit comments