@@ -28,24 +28,29 @@ import (
2828
2929 "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
3030 "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
31- "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3231)
3332
3433type nvmllib nvcdilib
3534
36- var _ Interface = (* nvmllib )(nil )
35+ var _ wrapped = (* nvmllib )(nil )
3736
38- // GetSpec should not be called for nvmllib
39- func (l * nvmllib ) GetSpec (... string ) (spec.Interface , error ) {
40- return nil , fmt .Errorf ("unexpected call to nvmllib.GetSpec()" )
41- }
37+ // GetCommonEdits generates a CDI specification that can be used for ANY devices
38+ func (l * nvmllib ) GetCommonEdits () (* cdi.ContainerEdits , error ) {
39+ common , err := l .newCommonNVMLDiscoverer ()
40+ if err != nil {
41+ return nil , fmt .Errorf ("failed to create discoverer for common entities: %v" , err )
42+ }
4243
43- // GetAllDeviceSpecs returns the device specs for all available devices.
44- func (l * nvmllib ) GetAllDeviceSpecs () ([]specs.Device , error ) {
45- var deviceSpecs []specs.Device
44+ return edits .FromDiscoverer (common )
45+ }
4646
47+ // GetDeviceSpecsByID returns the CDI device specs for the devices represented
48+ // by the requested identifiers. Here an identifier is one of the following:
49+ // * an index of a GPU or MIG device
50+ // * a UUID of a GPU or MIG device
51+ func (l * nvmllib ) GetDeviceSpecsByID (ids ... string ) ([]specs.Device , error ) {
4752 if r := l .nvmllib .Init (); r != nvml .SUCCESS {
48- return nil , fmt .Errorf ("failed to initialize NVML: %v " , r )
53+ return nil , fmt .Errorf ("failed to initialize NVML: %w " , r )
4954 }
5055 defer func () {
5156 if r := l .nvmllib .Shutdown (); r != nvml .SUCCESS {
@@ -66,93 +71,83 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
6671 }()
6772 }
6873
69- gpuDeviceSpecs , err := l .getGPUDeviceSpecs ()
70- if err != nil {
71- return nil , err
72- }
73- deviceSpecs = append (deviceSpecs , gpuDeviceSpecs ... )
74-
75- migDeviceSpecs , err := l .getMigDeviceSpecs ()
74+ generators , err := l .getDeviceSpecGeneratorsForIDs (ids ... )
7675 if err != nil {
7776 return nil , err
7877 }
79- deviceSpecs = append (deviceSpecs , migDeviceSpecs ... )
8078
81- return deviceSpecs , nil
79+ return generators . GetDeviceSpecs ()
8280}
8381
84- // GetCommonEdits generates a CDI specification that can be used for ANY devices
85- func (l * nvmllib ) GetCommonEdits () (* cdi.ContainerEdits , error ) {
86- common , err := l .newCommonNVMLDiscoverer ()
87- if err != nil {
88- return nil , fmt .Errorf ("failed to create discoverer for common entities: %v" , err )
82+ func (l * nvmllib ) newDeviceSpecGeneratorFromNVMLDevice (id string , nvmlDevice nvml.Device ) (deviceSpecGenerator , error ) {
83+ isMig , ret := nvmlDevice .IsMigDeviceHandle ()
84+ if ret != nvml .SUCCESS {
85+ return nil , ret
86+ }
87+ if isMig {
88+ return l .newMIGDeviceSpecGeneratorFromNVMLDevice (id , nvmlDevice )
8989 }
9090
91- return edits . FromDiscoverer ( common )
91+ return l . newFullGPUDeviceSpecGeneratorFromNVMLDevice ( id , nvmlDevice )
9292}
9393
94- // GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
95- // the provided identifiers, where an identifier is an index or UUID of a valid
96- // GPU device.
97- // Deprecated: Use GetDeviceSpecsBy instead.
98- func (l * nvmllib ) GetDeviceSpecsByID (ids ... string ) ([]specs.Device , error ) {
94+ func (l * nvmllib ) getDeviceSpecGeneratorsForIDs (ids ... string ) (deviceSpecGenerators , error ) {
9995 var identifiers []device.Identifier
10096 for _ , id := range ids {
101- identifiers = append (identifiers , device .Identifier (id ))
102- }
103- return l .GetDeviceSpecsBy (identifiers ... )
104- }
105-
106- // GetDeviceSpecsBy returns the device specs for devices with the specified identifiers.
107- func (l * nvmllib ) GetDeviceSpecsBy (identifiers ... device.Identifier ) ([]specs.Device , error ) {
108- for _ , id := range identifiers {
10997 if id == "all" {
110- return l .GetAllDeviceSpecs ()
98+ return l .getDeviceSpecGeneratorsForAllDevices ()
11199 }
100+ identifiers = append (identifiers , device .Identifier (id ))
112101 }
113102
114- var deviceSpecs []specs.Device
115-
116- if r := l .nvmllib .Init (); r != nvml .SUCCESS {
117- return nil , fmt .Errorf ("failed to initialize NVML: %w" , r )
103+ devices , err := l .getNVMLDevicesByID (identifiers ... )
104+ if err != nil {
105+ return nil , err
118106 }
119- defer func () {
120- if r := l .nvmllib .Shutdown (); r != nvml .SUCCESS {
121- l .logger .Warningf ("failed to shutdown NVML: %v" , r )
122- }
123- }()
124107
125- if l .nvsandboxutilslib != nil {
126- if r := l .nvsandboxutilslib .Init (l .driverRoot ); r != nvsandboxutils .SUCCESS {
127- l .logger .Warningf ("Failed to init nvsandboxutils: %v; ignoring" , r )
128- l .nvsandboxutilslib = nil
108+ var DeviceSpecGenerators deviceSpecGenerators
109+ for i , device := range devices {
110+ editor , err := l .newDeviceSpecGeneratorFromNVMLDevice (ids [i ], device )
111+ if err != nil {
112+ return nil , err
129113 }
130- defer func () {
131- if l .nvsandboxutilslib == nil {
132- return
133- }
134- _ = l .nvsandboxutilslib .Shutdown ()
135- }()
114+ DeviceSpecGenerators = append (DeviceSpecGenerators , editor )
136115 }
137116
138- nvmlDevices , err := l .getNVMLDevicesByID (identifiers ... )
117+ return DeviceSpecGenerators , nil
118+ }
119+
120+ func (l * nvmllib ) getDeviceSpecGeneratorsForAllDevices () ([]deviceSpecGenerator , error ) {
121+ var DeviceSpecGenerators []deviceSpecGenerator
122+ err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
123+ e := & fullGPUDeviceSpecGenerator {
124+ nvmllib : l ,
125+ id : fmt .Sprintf ("%d" , i ),
126+ device : d ,
127+ }
128+
129+ DeviceSpecGenerators = append (DeviceSpecGenerators , e )
130+ return nil
131+ })
139132 if err != nil {
140- return nil , fmt .Errorf ("failed to get NVML device handles : %w" , err )
133+ return nil , fmt .Errorf ("failed to get full GPU device editors : %w" , err )
141134 }
142135
143- for i , nvmlDevice := range nvmlDevices {
144- deviceEdits , err := l .getEditsForDevice (nvmlDevice )
145- if err != nil {
146- return nil , fmt .Errorf ("failed to get CDI device edits for identifier %q: %w" , identifiers [i ], err )
147- }
148- deviceSpec := specs.Device {
149- Name : string (identifiers [i ]),
150- ContainerEdits : * deviceEdits .ContainerEdits ,
136+ err = l .devicelib .VisitMigDevices (func (i int , d device.Device , j int , mig device.MigDevice ) error {
137+ e := & migDeviceSpecGenerator {
138+ nvmllib : l ,
139+ id : fmt .Sprintf ("%d:%d" , i , j ),
140+ parent : d ,
141+ device : mig ,
151142 }
152- deviceSpecs = append (deviceSpecs , deviceSpec )
143+ DeviceSpecGenerators = append (DeviceSpecGenerators , e )
144+ return nil
145+ })
146+ if err != nil {
147+ return nil , fmt .Errorf ("failed to get MIG device editors: %w" , err )
153148 }
154149
155- return deviceSpecs , nil
150+ return DeviceSpecGenerators , nil
156151}
157152
158153// TODO: move this to go-nvlib?
@@ -201,76 +196,21 @@ func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
201196 return nil , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
202197}
203198
204- func (l * nvmllib ) getEditsForDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
205- mig , err := nvmlDevice .IsMigDeviceHandle ()
206- if err != nvml .SUCCESS {
207- return nil , fmt .Errorf ("failed to determine if device handle is a MIG device: %w" , err )
208- }
209- if mig {
210- return l .getEditsForMIGDevice (nvmlDevice )
211- }
212- return l .getEditsForGPUDevice (nvmlDevice )
213- }
214-
215- func (l * nvmllib ) getEditsForGPUDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
216- nvlibDevice , err := l .devicelib .NewDevice (nvmlDevice )
217- if err != nil {
218- return nil , fmt .Errorf ("failed to construct device: %w" , err )
219- }
220- deviceEdits , err := l .GetGPUDeviceEdits (nvlibDevice )
221- if err != nil {
222- return nil , fmt .Errorf ("failed to get GPU device edits: %w" , err )
223- }
199+ type deviceSpecGenerators []deviceSpecGenerator
224200
225- return deviceEdits , nil
226- }
227-
228- func (l * nvmllib ) getEditsForMIGDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
229- nvmlParentDevice , ret := nvmlDevice .GetDeviceHandleFromMigDeviceHandle ()
230- if ret != nvml .SUCCESS {
231- return nil , fmt .Errorf ("failed to get parent device handle: %w" , ret )
232- }
233- nvlibMigDevice , err := l .devicelib .NewMigDevice (nvmlDevice )
234- if err != nil {
235- return nil , fmt .Errorf ("failed to construct device: %w" , err )
236- }
237- nvlibParentDevice , err := l .devicelib .NewDevice (nvmlParentDevice )
238- if err != nil {
239- return nil , fmt .Errorf ("failed to construct parent device: %w" , err )
240- }
241- return l .GetMIGDeviceEdits (nvlibParentDevice , nvlibMigDevice )
242- }
243-
244- func (l * nvmllib ) getGPUDeviceSpecs () ([]specs.Device , error ) {
245- var deviceSpecs []specs.Device
246- err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
247- specsForDevice , err := l .GetGPUDeviceSpecs (i , d )
248- if err != nil {
249- return err
201+ // GetDeviceSpecs returns the combined specs for each device spec generator.
202+ func (g deviceSpecGenerators ) GetDeviceSpecs () ([]specs.Device , error ) {
203+ var allDeviceSpecs []specs.Device
204+ for _ , dsg := range g {
205+ if dsg == nil {
206+ continue
250207 }
251- deviceSpecs = append (deviceSpecs , specsForDevice ... )
252-
253- return nil
254- })
255- if err != nil {
256- return nil , fmt .Errorf ("failed to generate CDI edits for GPU devices: %v" , err )
257- }
258- return deviceSpecs , err
259- }
260-
261- func (l * nvmllib ) getMigDeviceSpecs () ([]specs.Device , error ) {
262- var deviceSpecs []specs.Device
263- err := l .devicelib .VisitMigDevices (func (i int , d device.Device , j int , mig device.MigDevice ) error {
264- specsForDevice , err := l .GetMIGDeviceSpecs (i , d , j , mig )
208+ deviceSpecs , err := dsg .GetDeviceSpecs ()
265209 if err != nil {
266- return err
210+ return nil , err
267211 }
268- deviceSpecs = append (deviceSpecs , specsForDevice ... )
269-
270- return nil
271- })
272- if err != nil {
273- return nil , fmt .Errorf ("failed to generate CDI edits for GPU devices: %v" , err )
212+ allDeviceSpecs = append (allDeviceSpecs , deviceSpecs ... )
274213 }
275- return deviceSpecs , err
214+
215+ return allDeviceSpecs , nil
276216}
0 commit comments