@@ -66,6 +66,11 @@ func (m *nvmlMigConfigManager) GetMigConfig(gpu int) (types.MigConfig, error) {
6666 return nil , fmt .Errorf ("error getting device handle: %v" , ret )
6767 }
6868
69+ deviceMemory , ret := device .GetMemoryInfo ()
70+ if ret .Value () != nvml .SUCCESS {
71+ return nil , fmt .Errorf ("error getting device memory: %v" , ret )
72+ }
73+
6974 err := m .nvlib .Mig .Device (device ).AssertMigEnabled ()
7075 if err != nil {
7176 return nil , fmt .Errorf ("error asserting MIG enabled: %v" , err )
@@ -74,7 +79,7 @@ func (m *nvmlMigConfigManager) GetMigConfig(gpu int) (types.MigConfig, error) {
7479 migConfig := types.MigConfig {}
7580 err = m .nvlib .Mig .Device (device ).WalkGpuInstances (func (gi nvml.GpuInstance , giProfileID int , giProfileInfo nvml.GpuInstanceProfileInfo ) error {
7681 err := m .nvlib .Mig .GpuInstance (gi ).WalkComputeInstances (func (ci nvml.ComputeInstance , ciProfileID int , ciEngProfileID int , ciProfileInfo nvml.ComputeInstanceProfileInfo ) error {
77- mp := types .NewMigProfile (giProfileID , ciProfileID , ciEngProfileID , & giProfileInfo , & ciProfileInfo )
82+ mp := types .NewMigProfile (giProfileID , ciProfileID , ciEngProfileID , & giProfileInfo , & ciProfileInfo , deviceMemory . Total )
7883 migConfig [mp .String ()]++
7984 return nil
8085 })
@@ -102,6 +107,11 @@ func (m *nvmlMigConfigManager) SetMigConfig(gpu int, config types.MigConfig) err
102107 return fmt .Errorf ("error getting device handle: %v" , ret )
103108 }
104109
110+ deviceMemory , ret := device .GetMemoryInfo ()
111+ if ret .Value () != nvml .SUCCESS {
112+ return fmt .Errorf ("error getting device memory: %v" , ret )
113+ }
114+
105115 err := m .nvlib .Mig .Device (device ).AssertMigEnabled ()
106116 if err != nil {
107117 return fmt .Errorf ("error asserting MIG enabled: %v" , err )
@@ -169,7 +179,7 @@ func (m *nvmlMigConfigManager) SetMigConfig(gpu int, config types.MigConfig) err
169179 return fmt .Errorf ("error creating Compute instance for '%v': %v" , mp , ret )
170180 }
171181
172- valid := types .NewMigProfile (mp .GIProfileID , mp .CIProfileID , mp .CIEngProfileID , & giProfileInfo , & ciProfileInfo )
182+ valid := types .NewMigProfile (mp .GIProfileID , mp .CIProfileID , mp .CIEngProfileID , & giProfileInfo , & ciProfileInfo , deviceMemory . Total )
173183 if ! mp .Equals (valid ) {
174184 if reuseGI {
175185 reuseGI = false
0 commit comments