@@ -108,36 +108,66 @@ func (l deviceLib) alwaysShutdown() {
108108 }
109109}
110110
111- func (l deviceLib ) enumerateAllPossibleDevices () (AllocatableDevices , error ) {
111+ func (l deviceLib ) enumerateAllPossibleDevices (config * Config ) (AllocatableDevices , error ) {
112+ alldevices := make (AllocatableDevices )
113+ deviceClasses := config .flags .deviceClasses
114+
115+ if deviceClasses .Has (GpuDeviceType ) || deviceClasses .Has (MigDeviceType ) {
116+ gms , err := l .enumerateGpusAndMigDevices (config )
117+ if err != nil {
118+ return nil , fmt .Errorf ("error enumerating IMEX devices: %w" , err )
119+ }
120+ for k , v := range gms {
121+ alldevices [k ] = v
122+ }
123+ }
124+
125+ if deviceClasses .Has (ImexChannelType ) {
126+ imex , err := l .enumerateImexChannels (config )
127+ if err != nil {
128+ return nil , fmt .Errorf ("error enumerating IMEX devices: %w" , err )
129+ }
130+ for k , v := range imex {
131+ alldevices [k ] = v
132+ }
133+ }
134+
135+ return alldevices , nil
136+ }
137+
138+ func (l deviceLib ) enumerateGpusAndMigDevices (config * Config ) (AllocatableDevices , error ) {
112139 if err := l .Init (); err != nil {
113140 return nil , err
114141 }
115142 defer l .alwaysShutdown ()
116143
117- alldevices := make (AllocatableDevices )
144+ devices := make (AllocatableDevices )
145+ deviceClasses := config .flags .deviceClasses
118146 err := l .VisitDevices (func (i int , d nvdev.Device ) error {
119147 gpuInfo , err := l .getGpuInfo (i , d )
120148 if err != nil {
121149 return fmt .Errorf ("error getting info for GPU %d: %w" , i , err )
122150 }
123151
124- migs , err := l .getMigDevices (gpuInfo )
125- if err != nil {
126- return fmt .Errorf ("error getting MIG devices for GPU %d: %w" , i , err )
127- }
128-
129- for _ , migDeviceInfo := range migs {
152+ if deviceClasses .Has (GpuDeviceType ) && ! gpuInfo .migEnabled {
130153 deviceInfo := & AllocatableDevice {
131- Mig : migDeviceInfo ,
154+ Gpu : gpuInfo ,
132155 }
133- alldevices [ migDeviceInfo .CanonicalName ()] = deviceInfo
156+ devices [ gpuInfo .CanonicalName ()] = deviceInfo
134157 }
135158
136- if ! gpuInfo .migEnabled && len (migs ) == 0 {
137- deviceInfo := & AllocatableDevice {
138- Gpu : gpuInfo ,
159+ if deviceClasses .Has (MigDeviceType ) {
160+ migs , err := l .getMigDevices (gpuInfo )
161+ if err != nil {
162+ return fmt .Errorf ("error getting MIG devices for GPU %d: %w" , i , err )
163+ }
164+
165+ for _ , migDeviceInfo := range migs {
166+ deviceInfo := & AllocatableDevice {
167+ Mig : migDeviceInfo ,
168+ }
169+ devices [migDeviceInfo .CanonicalName ()] = deviceInfo
139170 }
140- alldevices [gpuInfo .CanonicalName ()] = deviceInfo
141171 }
142172
143173 return nil
@@ -146,6 +176,12 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
146176 return nil , fmt .Errorf ("error visiting devices: %w" , err )
147177 }
148178
179+ return devices , nil
180+ }
181+
182+ func (l deviceLib ) enumerateImexChannels (config * Config ) (AllocatableDevices , error ) {
183+ devices := make (AllocatableDevices )
184+
149185 imexChannelCount , err := l .getImexChannelCount ()
150186 if err != nil {
151187 return nil , fmt .Errorf ("error getting IMEX channel count: %w" , err )
@@ -157,10 +193,10 @@ func (l deviceLib) enumerateAllPossibleDevices() (AllocatableDevices, error) {
157193 deviceInfo := & AllocatableDevice {
158194 ImexChannel : imexChannelInfo ,
159195 }
160- alldevices [imexChannelInfo .CanonicalName ()] = deviceInfo
196+ devices [imexChannelInfo .CanonicalName ()] = deviceInfo
161197 }
162198
163- return alldevices , nil
199+ return devices , nil
164200}
165201
166202func (l deviceLib ) getGpuInfo (index int , device nvdev.Device ) (* GpuInfo , error ) {
0 commit comments