Skip to content

Commit 1198f76

Browse files
committed
Merge branch 'cherry-picks-v0.4.3' into 'release-0.4'
Cherry-picks of required patches for v0.4.3 See merge request nvidia/cloud-native/mig-parted!106
2 parents 5ea48f5 + a3865a1 commit 1198f76

File tree

7 files changed

+85
-5
lines changed

7 files changed

+85
-5
lines changed

deployments/systemd/nvidia-mig-manager.service

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Description=Configure MIG on NVIDIA GPUs
1717
DefaultDependencies=no
1818
After=sysinit.target local-fs.target
19-
Before=basic.target nvidia-persistenced.service
19+
Before=basic.target nvidia-persistenced.service systemd-resolved.service
2020

2121
[Service]
2222
Type=oneshot

internal/nvml/mock.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type MockA100Device struct {
2929
MigMode int
3030
GpuInstances map[*MockA100GpuInstance]struct{}
3131
GpuInstanceCounter uint32
32+
MemoryInfo Memory
3233
}
3334
type MockA100GpuInstance struct {
3435
Info GpuInstanceInfo
@@ -330,6 +331,7 @@ func NewMockA100Device(index int) Device {
330331
Index: index,
331332
GpuInstances: make(map[*MockA100GpuInstance]struct{}),
332333
GpuInstanceCounter: 0,
334+
MemoryInfo: Memory{42949672960, 0, 0},
333335
}
334336
}
335337

@@ -387,6 +389,10 @@ func (d *MockA100Device) GetUUID() (string, Return) {
387389
return d.UUID, MockReturn(SUCCESS)
388390
}
389391

392+
func (d *MockA100Device) GetMemoryInfo() (Memory, Return) {
393+
return d.MemoryInfo, MockReturn(SUCCESS)
394+
}
395+
390396
func (d *MockA100Device) GetPciInfo() (PciInfo, Return) {
391397
p := PciInfo{
392398
PciDeviceId: 0x20B010DE,

internal/nvml/nvml.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ func (d nvmlDevice) GetUUID() (string, Return) {
7272
return u, nvmlReturn(r)
7373
}
7474

75+
func (d nvmlDevice) GetMemoryInfo() (Memory, Return) {
76+
m, r := nvml.Device(d).GetMemoryInfo()
77+
return Memory(m), nvmlReturn(r)
78+
}
79+
7580
func (d nvmlDevice) GetPciInfo() (PciInfo, Return) {
7681
p, r := nvml.Device(d).GetPciInfo()
7782
return PciInfo(p), nvmlReturn(r)

internal/nvml/types.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type Interface interface {
3232
type Device interface {
3333
GetIndex() (int, Return)
3434
GetUUID() (string, Return)
35+
GetMemoryInfo() (Memory, Return)
3536
GetPciInfo() (PciInfo, Return)
3637
SetMigMode(Mode int) (Return, Return)
3738
GetMigMode() (int, int, Return)
@@ -69,6 +70,7 @@ type ComputeInstanceInfo struct {
6970
Placement ComputeInstancePlacement
7071
}
7172

73+
type Memory nvml.Memory
7274
type PciInfo nvml.PciInfo
7375
type GpuInstanceProfileInfo nvml.GpuInstanceProfileInfo
7476
type GpuInstancePlacement nvml.GpuInstancePlacement

pkg/mig/config/config.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ func (m *nvmlMigConfigManager) GetMigConfig(gpu int) (types.MigConfig, error) {
6666
return nil, fmt.Errorf("error getting device handle: %v", ret)
6767
}
6868

69+
deviceMemory, ret := device.GetMemoryInfo()
70+
if ret.Value() != nvml.SUCCESS {
71+
return nil, fmt.Errorf("error getting device memory: %v", ret)
72+
}
73+
6974
err := m.nvlib.Mig.Device(device).AssertMigEnabled()
7075
if err != nil {
7176
return nil, fmt.Errorf("error asserting MIG enabled: %v", err)
@@ -74,7 +79,7 @@ func (m *nvmlMigConfigManager) GetMigConfig(gpu int) (types.MigConfig, error) {
7479
migConfig := types.MigConfig{}
7580
err = m.nvlib.Mig.Device(device).WalkGpuInstances(func(gi nvml.GpuInstance, giProfileID int, giProfileInfo nvml.GpuInstanceProfileInfo) error {
7681
err := m.nvlib.Mig.GpuInstance(gi).WalkComputeInstances(func(ci nvml.ComputeInstance, ciProfileID int, ciEngProfileID int, ciProfileInfo nvml.ComputeInstanceProfileInfo) error {
77-
mp := types.NewMigProfile(giProfileID, ciProfileID, ciEngProfileID, &giProfileInfo, &ciProfileInfo)
82+
mp := types.NewMigProfile(giProfileID, ciProfileID, ciEngProfileID, &giProfileInfo, &ciProfileInfo, deviceMemory.Total)
7883
migConfig[mp.String()]++
7984
return nil
8085
})
@@ -102,6 +107,11 @@ func (m *nvmlMigConfigManager) SetMigConfig(gpu int, config types.MigConfig) err
102107
return fmt.Errorf("error getting device handle: %v", ret)
103108
}
104109

110+
deviceMemory, ret := device.GetMemoryInfo()
111+
if ret.Value() != nvml.SUCCESS {
112+
return fmt.Errorf("error getting device memory: %v", ret)
113+
}
114+
105115
err := m.nvlib.Mig.Device(device).AssertMigEnabled()
106116
if err != nil {
107117
return fmt.Errorf("error asserting MIG enabled: %v", err)
@@ -169,7 +179,7 @@ func (m *nvmlMigConfigManager) SetMigConfig(gpu int, config types.MigConfig) err
169179
return fmt.Errorf("error creating Compute instance for '%v': %v", mp, ret)
170180
}
171181

172-
valid := types.NewMigProfile(mp.GIProfileID, mp.CIProfileID, mp.CIEngProfileID, &giProfileInfo, &ciProfileInfo)
182+
valid := types.NewMigProfile(mp.GIProfileID, mp.CIProfileID, mp.CIEngProfileID, &giProfileInfo, &ciProfileInfo, deviceMemory.Total)
173183
if !mp.Equals(valid) {
174184
if reuseGI {
175185
reuseGI = false

pkg/types/mig_profile.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package types
1818

1919
import (
2020
"fmt"
21+
"math"
2122
"strconv"
2223
"strings"
2324

@@ -41,11 +42,11 @@ type MigProfile struct {
4142
}
4243

4344
// NewMigProfile constructs a new MigProfile struct using info from the giProfiles and ciProfiles used to create it.
44-
func NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, giProfileInfo *nvml.GpuInstanceProfileInfo, ciProfileInfo *nvml.ComputeInstanceProfileInfo) *MigProfile {
45+
func NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, giProfileInfo *nvml.GpuInstanceProfileInfo, ciProfileInfo *nvml.ComputeInstanceProfileInfo, totalDeviceMemory uint64) *MigProfile {
4546
return &MigProfile{
4647
C: int(ciProfileInfo.SliceCount),
4748
G: int(giProfileInfo.SliceCount),
48-
GB: int((giProfileInfo.MemorySizeMB + 1024 - 1) / 1024),
49+
GB: int(getMigMemorySizeInGB(totalDeviceMemory, giProfileInfo.MemorySizeMB)),
4950
GIProfileID: giProfileID,
5051
CIProfileID: ciProfileID,
5152
CIEngProfileID: ciEngProfileID,
@@ -256,3 +257,13 @@ func parseMigProfileAttributes(s string) ([]string, error) {
256257
}
257258
return attr, nil
258259
}
260+
261+
func getMigMemorySizeInGB(totalDeviceMemory, migMemorySizeMB uint64) uint64 {
262+
const fracDenominator = 8
263+
const oneMB = 1024 * 1024
264+
const oneGB = 1024 * 1024 * 1024
265+
fractionalGpuMem := (float64(migMemorySizeMB) * oneMB) / float64(totalDeviceMemory)
266+
fractionalGpuMem = math.Ceil(fractionalGpuMem*fracDenominator) / fracDenominator
267+
totalMemGB := float64((totalDeviceMemory + oneGB - 1) / oneGB)
268+
return uint64(math.Round(fractionalGpuMem * totalMemGB))
269+
}

pkg/types/types_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package types
1818

1919
import (
20+
"fmt"
2021
"testing"
2122

2223
"github.com/stretchr/testify/require"
@@ -266,3 +267,48 @@ func TestParseMigProfile(t *testing.T) {
266267
})
267268
}
268269
}
270+
271+
func TestGetMigMemorySizeInGB(t *testing.T) {
272+
type testCase struct {
273+
totalDeviceMemory uint64
274+
migMemorySizeMB uint64
275+
expectedMemorySizeGB uint64
276+
}
277+
278+
const maxMemorySlices = 8
279+
const oneMB = uint64(1024 * 1024)
280+
const oneGB = uint64(1024 * 1024 * 1024)
281+
282+
totalDeviceMemory := []uint64{
283+
24 * oneGB,
284+
40 * oneGB,
285+
80 * oneGB,
286+
}
287+
288+
testCases := []testCase{}
289+
for _, tdm := range totalDeviceMemory {
290+
sliceSize := tdm / maxMemorySlices
291+
292+
const stepSize = oneGB / 32
293+
for i := stepSize; i <= tdm; i += stepSize {
294+
tc := testCase{
295+
totalDeviceMemory: tdm,
296+
migMemorySizeMB: i / oneMB,
297+
}
298+
for j := uint64(sliceSize); j <= tdm; j += sliceSize {
299+
if i <= j {
300+
tc.expectedMemorySizeGB = j / oneGB
301+
break
302+
}
303+
}
304+
testCases = append(testCases, tc)
305+
}
306+
}
307+
308+
for _, tc := range testCases {
309+
t.Run(fmt.Sprintf("%v", tc.migMemorySizeMB), func(t *testing.T) {
310+
memorySizeGB := getMigMemorySizeInGB(tc.totalDeviceMemory, tc.migMemorySizeMB)
311+
require.Equal(t, int(tc.expectedMemorySizeGB), int(memorySizeGB))
312+
})
313+
}
314+
}

0 commit comments

Comments
 (0)