@@ -72,18 +72,22 @@ func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (DeviceSpecGenera
7272 identifiers = append (identifiers , device .Identifier (id ))
7373 }
7474
75- devices , err := l .getNVMLDevicesByID (identifiers ... )
75+ uuids , err := l .normalizeDeviceIDs (identifiers ... )
7676 if err != nil {
7777 return nil , err
7878 }
7979
8080 var DeviceSpecGenerators DeviceSpecGenerators
81- for i , device := range devices {
82- editor , err := l .newDeviceSpecGeneratorFromNVMLDevice (ids [i ], device )
81+ for _ , uuid := range uuids {
82+ device , ret := l .nvmllib .DeviceGetHandleByUUID (string (uuid ))
83+ if ret != nvml .SUCCESS {
84+ return nil , fmt .Errorf ("failed to get device handle from UUID: %v" , ret )
85+ }
86+ generator , err := l .newDeviceSpecGeneratorFromNVMLDevice (string (uuid ), device )
8387 if err != nil {
8488 return nil , err
8589 }
86- DeviceSpecGenerators = append (DeviceSpecGenerators , editor )
90+ DeviceSpecGenerators = append (DeviceSpecGenerators , generator )
8791 }
8892
8993 return DeviceSpecGenerators , nil
@@ -92,7 +96,7 @@ func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (DeviceSpecGenera
9296func (l * nvmllib ) newDeviceSpecGeneratorFromNVMLDevice (id string , nvmlDevice nvml.Device ) (DeviceSpecGenerator , error ) {
9397 isMig , ret := nvmlDevice .IsMigDeviceHandle ()
9498 if ret != nvml .SUCCESS {
95- return nil , ret
99+ return nil , fmt . Errorf ( "%v" , ret )
96100 }
97101 if isMig {
98102 return l .newMIGDeviceSpecGeneratorFromNVMLDevice (id , nvmlDevice )
@@ -114,33 +118,23 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
114118 if isMigEnabled {
115119 return nil
116120 }
117- e := & fullGPUDeviceSpecGenerator {
118- nvmllib : l ,
119- id : fmt .Sprintf ("%d" , i ),
120- index : i ,
121- device : d ,
121+ fullGPU , err := l .newFullGPUDeviceSpecGeneratorFromDevice (i , d )
122+ if err != nil {
123+ return err
122124 }
123-
124- DeviceSpecGenerators = append (DeviceSpecGenerators , e )
125+ DeviceSpecGenerators = append (DeviceSpecGenerators , fullGPU )
125126 return nil
126127 })
127128 if err != nil {
128129 return nil , fmt .Errorf ("failed to get full GPU device editors: %w" , err )
129130 }
130131
131132 err = l .devicelib .VisitMigDevices (func (i int , d device.Device , j int , mig device.MigDevice ) error {
132- parentGenerator := & fullGPUDeviceSpecGenerator {
133- nvmllib : l ,
134- index : i ,
135- id : fmt .Sprintf ("%d:%d" , i , j ),
136- device : d ,
137- }
138- migGenerator := & migDeviceSpecGenerator {
139- fullGPUDeviceSpecGenerator : parentGenerator ,
140- migIndex : j ,
141- migDevice : mig ,
142- }
143- DeviceSpecGenerators = append (DeviceSpecGenerators , parentGenerator , migGenerator )
133+ migDevice , err := l .newMIGDeviceSpecGeneratorFromDevice (i , d , j , mig )
134+ if err != nil {
135+ return err
136+ }
137+ DeviceSpecGenerators = append (DeviceSpecGenerators , migDevice )
144138 return nil
145139 })
146140 if err != nil {
@@ -151,50 +145,68 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
151145}
152146
153147// TODO: move this to go-nvlib?
154- func (l * nvmllib ) getNVMLDevicesByID (identifiers ... device.Identifier ) ([]nvml.Device , error ) {
155- var devices []nvml.Device
148+ // normalizeDeviceID returns the UUIDs of the devices specified by the identifier.
149+ func (l * nvmllib ) normalizeDeviceIDs (identifiers ... device.Identifier ) ([]device.Identifier , error ) {
150+ var uuids []device.Identifier
156151 for _ , id := range identifiers {
157- dev , err := l .getNVMLDeviceByID (id )
158- if err != nvml . SUCCESS {
159- return nil , fmt . Errorf ( "failed to get NVML device handle for identifier %q: %w" , id , err )
152+ uuid , err := l .normalizeDeviceID (id )
153+ if err != nil {
154+ return nil , err
160155 }
161- devices = append (devices , dev )
156+ uuids = append (uuids , uuid )
162157 }
163- return devices , nil
158+ return uuids , nil
164159}
165160
166- func (l * nvmllib ) getNVMLDeviceByID (id device.Identifier ) (nvml. Device , error ) {
161+ func (l * nvmllib ) normalizeDeviceID (id device.Identifier ) (device. Identifier , error ) {
167162 var err error
168163
169164 if id .IsUUID () {
170- return l . nvmllib . DeviceGetHandleByUUID ( string ( id ))
165+ return id , nil
171166 }
172167
173168 if id .IsGpuIndex () {
174- if idx , err := strconv .Atoi (string (id )); err == nil {
175- return l .nvmllib .DeviceGetHandleByIndex (idx )
169+ idx , err := strconv .Atoi (string (id ))
170+ if err != nil {
171+ return "" , fmt .Errorf ("failed to convert device index to an int: %w" , err )
172+ }
173+ dev , ret := l .nvmllib .DeviceGetHandleByIndex (idx )
174+ if ret != nvml .SUCCESS {
175+ return "" , fmt .Errorf ("failed to get device handle from index: %v" , ret )
176+ }
177+ uuid , ret := dev .GetUUID ()
178+ if ret != nvml .SUCCESS {
179+ return "" , fmt .Errorf ("failed to get device UUID: %v" , ret )
176180 }
177- return nil , fmt . Errorf ( "failed to convert device index to an int: %w" , err )
181+ return device . Identifier ( uuid ), nil
178182 }
179183
180184 if id .IsMigIndex () {
181185 var gpuIdx , migIdx int
182186 var parent nvml.Device
183187 split := strings .SplitN (string (id ), ":" , 2 )
184188 if gpuIdx , err = strconv .Atoi (split [0 ]); err != nil {
185- return nil , fmt .Errorf ("failed to convert device index to an int: %w" , err )
189+ return "" , fmt .Errorf ("failed to convert device index to an int: %w" , err )
186190 }
187191 if migIdx , err = strconv .Atoi (split [1 ]); err != nil {
188- return nil , fmt .Errorf ("failed to convert device index to an int: %w" , err )
192+ return "" , fmt .Errorf ("failed to convert device index to an int: %w" , err )
189193 }
190194 parent , ret := l .nvmllib .DeviceGetHandleByIndex (gpuIdx )
191195 if ret != nvml .SUCCESS {
192- return nil , fmt .Errorf ("failed to get parent device handle: %v" , ret )
196+ return "" , fmt .Errorf ("failed to get parent device handle: %v" , ret )
197+ }
198+ mig , ret := parent .GetMigDeviceHandleByIndex (migIdx )
199+ if ret != nvml .SUCCESS {
200+ return "" , fmt .Errorf ("failed to get MIG handle by index: %v" , ret )
201+ }
202+ uuid , ret := mig .GetUUID ()
203+ if ret != nvml .SUCCESS {
204+ return "" , fmt .Errorf ("failed to get MIG UUID: %v" , ret )
193205 }
194- return parent . GetMigDeviceHandleByIndex ( migIdx )
206+ return device . Identifier ( uuid ), nil
195207 }
196208
197- return nil , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
209+ return "" , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
198210}
199211
200212func (l * nvmllib ) init () error {
0 commit comments