Skip to content

Commit 62794e9

Browse files
committed
[no-relnote] Fix CDI spec generation
These changes fix a bug in CDI spec generation introduced in #1166 where device handles become invalid when NVML is shut down and initialized again. Here we explicitly store UUIDs and use these to query the device handles when generating the CDI specification. Signed-off-by: Evan Lezar <[email protected]>
1 parent ac49ea2 commit 62794e9

File tree

4 files changed

+163
-83
lines changed

4 files changed

+163
-83
lines changed

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,48 @@ import (
3434
// single full GPU.
3535
type fullGPUDeviceSpecGenerator struct {
3636
*nvmllib
37-
id string
38-
index int
39-
device device.Device
37+
uuid string
38+
index int
4039
}
4140

4241
var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
4342

44-
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
45-
device, err := l.devicelib.NewDevice(nvmlDevice)
46-
if err != nil {
47-
return nil, err
43+
func (l *fullGPUDeviceSpecGenerator) GetUUID() (string, error) {
44+
return l.uuid, nil
45+
}
46+
47+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromDevice(index int, d device.Device) (*fullGPUDeviceSpecGenerator, error) {
48+
uuid, ret := d.GetUUID()
49+
if ret != nvml.SUCCESS {
50+
return nil, fmt.Errorf("failed to get device UUID: %v", ret)
51+
}
52+
e := &fullGPUDeviceSpecGenerator{
53+
nvmllib: l,
54+
uuid: uuid,
55+
index: index,
4856
}
4957

58+
return e, nil
59+
}
60+
61+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(uuid string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
5062
index, ret := nvmlDevice.GetIndex()
5163
if ret != nvml.SUCCESS {
5264
return nil, fmt.Errorf("failed to get device index: %v", ret)
5365
}
5466

5567
e := &fullGPUDeviceSpecGenerator{
5668
nvmllib: l,
57-
id: id,
69+
uuid: uuid,
5870
index: index,
59-
device: device,
6071
}
6172
return e, nil
6273
}
6374

6475
func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
6576
deviceEdits, err := l.getDeviceEdits()
6677
if err != nil {
67-
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", l.id, err)
78+
return nil, fmt.Errorf("failed to get CDI device edits: %w", err)
6879
}
6980

7081
names, err := l.getNames()
@@ -84,14 +95,23 @@ func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
8495
return deviceSpecs, nil
8596
}
8697

98+
func (l *fullGPUDeviceSpecGenerator) device() (device.Device, error) {
99+
return l.devicelib.NewDeviceByUUID(l.uuid)
100+
}
101+
87102
// getDeviceEdits returns the CDI edits for the full GPU identified by the generator's stored UUID.
88103
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
89-
device, err := l.newFullGPUDiscoverer(l.device)
104+
device, err := l.device()
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
deviceDiscoverer, err := l.newFullGPUDiscoverer(device)
90110
if err != nil {
91111
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
92112
}
93113

94-
editsForDevice, err := edits.FromDiscoverer(device)
114+
editsForDevice, err := edits.FromDiscoverer(deviceDiscoverer)
95115
if err != nil {
96116
return nil, fmt.Errorf("failed to create container edits for device: %v", err)
97117
}
@@ -100,7 +120,7 @@ func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, erro
100120
}
101121

102122
func (l *fullGPUDeviceSpecGenerator) getNames() ([]string, error) {
103-
return l.deviceNamers.GetDeviceNames(l.index, convert{l.device})
123+
return l.deviceNamers.GetDeviceNames(l.index, l)
104124
}
105125

106126
// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.

pkg/nvcdi/lib-nvml.go

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,22 @@ func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (DeviceSpecGenera
7272
identifiers = append(identifiers, device.Identifier(id))
7373
}
7474

75-
devices, err := l.getNVMLDevicesByID(identifiers...)
75+
uuids, err := l.normalizeDeviceIDs(identifiers...)
7676
if err != nil {
7777
return nil, err
7878
}
7979

8080
var DeviceSpecGenerators DeviceSpecGenerators
81-
for i, device := range devices {
82-
editor, err := l.newDeviceSpecGeneratorFromNVMLDevice(ids[i], device)
81+
for _, uuid := range uuids {
82+
device, ret := l.nvmllib.DeviceGetHandleByUUID(string(uuid))
83+
if ret != nvml.SUCCESS {
84+
return nil, fmt.Errorf("failed to get device handle from UUID: %v", ret)
85+
}
86+
generator, err := l.newDeviceSpecGeneratorFromNVMLDevice(string(uuid), device)
8387
if err != nil {
8488
return nil, err
8589
}
86-
DeviceSpecGenerators = append(DeviceSpecGenerators, editor)
90+
DeviceSpecGenerators = append(DeviceSpecGenerators, generator)
8791
}
8892

8993
return DeviceSpecGenerators, nil
@@ -92,7 +96,7 @@ func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (DeviceSpecGenera
9296
func (l *nvmllib) newDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
9397
isMig, ret := nvmlDevice.IsMigDeviceHandle()
9498
if ret != nvml.SUCCESS {
95-
return nil, ret
99+
return nil, fmt.Errorf("%v", ret)
96100
}
97101
if isMig {
98102
return l.newMIGDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice)
@@ -114,33 +118,23 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
114118
if isMigEnabled {
115119
return nil
116120
}
117-
e := &fullGPUDeviceSpecGenerator{
118-
nvmllib: l,
119-
id: fmt.Sprintf("%d", i),
120-
index: i,
121-
device: d,
121+
fullGPU, err := l.newFullGPUDeviceSpecGeneratorFromDevice(i, d)
122+
if err != nil {
123+
return err
122124
}
123-
124-
DeviceSpecGenerators = append(DeviceSpecGenerators, e)
125+
DeviceSpecGenerators = append(DeviceSpecGenerators, fullGPU)
125126
return nil
126127
})
127128
if err != nil {
128129
return nil, fmt.Errorf("failed to get full GPU device editors: %w", err)
129130
}
130131

131132
err = l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
132-
parentGenerator := &fullGPUDeviceSpecGenerator{
133-
nvmllib: l,
134-
index: i,
135-
id: fmt.Sprintf("%d:%d", i, j),
136-
device: d,
137-
}
138-
migGenerator := &migDeviceSpecGenerator{
139-
fullGPUDeviceSpecGenerator: parentGenerator,
140-
migIndex: j,
141-
migDevice: mig,
142-
}
143-
DeviceSpecGenerators = append(DeviceSpecGenerators, parentGenerator, migGenerator)
133+
migDevice, err := l.newMIGDeviceSpecGeneratorFromDevice(i, d, j, mig)
134+
if err != nil {
135+
return err
136+
}
137+
DeviceSpecGenerators = append(DeviceSpecGenerators, migDevice)
144138
return nil
145139
})
146140
if err != nil {
@@ -151,50 +145,68 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
151145
}
152146

153147
// TODO: move this to go-nvlib?
154-
func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) {
155-
var devices []nvml.Device
148+
// normalizeDeviceIDs returns the UUIDs of the devices specified by the identifiers.
149+
func (l *nvmllib) normalizeDeviceIDs(identifiers ...device.Identifier) ([]device.Identifier, error) {
150+
var uuids []device.Identifier
156151
for _, id := range identifiers {
157-
dev, err := l.getNVMLDeviceByID(id)
158-
if err != nvml.SUCCESS {
159-
return nil, fmt.Errorf("failed to get NVML device handle for identifier %q: %w", id, err)
152+
uuid, err := l.normalizeDeviceID(id)
153+
if err != nil {
154+
return nil, err
160155
}
161-
devices = append(devices, dev)
156+
uuids = append(uuids, uuid)
162157
}
163-
return devices, nil
158+
return uuids, nil
164159
}
165160

166-
func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
161+
func (l *nvmllib) normalizeDeviceID(id device.Identifier) (device.Identifier, error) {
167162
var err error
168163

169164
if id.IsUUID() {
170-
return l.nvmllib.DeviceGetHandleByUUID(string(id))
165+
return id, nil
171166
}
172167

173168
if id.IsGpuIndex() {
174-
if idx, err := strconv.Atoi(string(id)); err == nil {
175-
return l.nvmllib.DeviceGetHandleByIndex(idx)
169+
idx, err := strconv.Atoi(string(id))
170+
if err != nil {
171+
return "", fmt.Errorf("failed to convert device index to an int: %w", err)
172+
}
173+
dev, ret := l.nvmllib.DeviceGetHandleByIndex(idx)
174+
if ret != nvml.SUCCESS {
175+
return "", fmt.Errorf("failed to get device handle from index: %v", ret)
176+
}
177+
uuid, ret := dev.GetUUID()
178+
if ret != nvml.SUCCESS {
179+
return "", fmt.Errorf("failed to get device UUID: %v", ret)
176180
}
177-
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
181+
return device.Identifier(uuid), nil
178182
}
179183

180184
if id.IsMigIndex() {
181185
var gpuIdx, migIdx int
182186
var parent nvml.Device
183187
split := strings.SplitN(string(id), ":", 2)
184188
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
185-
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
189+
return "", fmt.Errorf("failed to convert device index to an int: %w", err)
186190
}
187191
if migIdx, err = strconv.Atoi(split[1]); err != nil {
188-
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
192+
return "", fmt.Errorf("failed to convert device index to an int: %w", err)
189193
}
190194
parent, ret := l.nvmllib.DeviceGetHandleByIndex(gpuIdx)
191195
if ret != nvml.SUCCESS {
192-
return nil, fmt.Errorf("failed to get parent device handle: %v", ret)
196+
return "", fmt.Errorf("failed to get parent device handle: %v", ret)
197+
}
198+
mig, ret := parent.GetMigDeviceHandleByIndex(migIdx)
199+
if ret != nvml.SUCCESS {
200+
return "", fmt.Errorf("failed to get MIG handle by index: %v", ret)
201+
}
202+
uuid, ret := mig.GetUUID()
203+
if ret != nvml.SUCCESS {
204+
return "", fmt.Errorf("failed to get MIG UUID: %v", ret)
193205
}
194-
return parent.GetMigDeviceHandleByIndex(migIdx)
206+
return device.Identifier(uuid), nil
195207
}
196208

197-
return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id)
209+
return "", fmt.Errorf("identifier is not a valid UUID or index: %q", id)
198210
}
199211

200212
func (l *nvmllib) init() error {

pkg/nvcdi/lib-nvml_test.go

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestNvmllibGetDeviceSpecGeneratorsForIDs(t *testing.T) {
7171
if s == "GPU-12345678-1234-1234-1234-123456789abc" {
7272
return server.Devices[3], nvml.SUCCESS
7373
}
74-
return nil, nvml.ERROR_INVALID_ARGUMENT
74+
return server.Devices[0], nvml.SUCCESS
7575
}
7676
},
7777
expectedError: nil,
@@ -81,24 +81,35 @@ func TestNvmllibGetDeviceSpecGeneratorsForIDs(t *testing.T) {
8181
name: "MIG device index",
8282
ids: []string{"0:0"},
8383
setupMock: func(server *dgxa100.Server) {
84+
mig := &mocknvml.Device{
85+
IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
86+
return true, nvml.SUCCESS
87+
},
88+
GetDeviceHandleFromMigDeviceHandleFunc: func() (nvml.Device, nvml.Return) {
89+
return server.Devices[0], nvml.SUCCESS
90+
},
91+
GetIndexFunc: func() (int, nvml.Return) {
92+
return 0, nvml.SUCCESS
93+
},
94+
GetUUIDFunc: func() (string, nvml.Return) {
95+
return "MIG-foo", nvml.SUCCESS
96+
},
97+
}
98+
8499
server.Devices[0].(*dgxa100.Device).GetMigDeviceHandleByIndexFunc = func(n int) (nvml.Device, nvml.Return) {
85100
if n != 0 {
86101
return nil, nvml.ERROR_INVALID_ARGUMENT
87102
}
88103

89-
mig := &mocknvml.Device{
90-
IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
91-
return true, nvml.SUCCESS
92-
},
93-
GetDeviceHandleFromMigDeviceHandleFunc: func() (nvml.Device, nvml.Return) {
94-
return server.Devices[0], nvml.SUCCESS
95-
},
96-
GetIndexFunc: func() (int, nvml.Return) {
97-
return 0, nvml.SUCCESS
98-
},
99-
}
100104
return mig, nvml.SUCCESS
101105
}
106+
107+
server.DeviceGetHandleByUUIDFunc = func(s string) (nvml.Device, nvml.Return) {
108+
if s == "MIG-foo" {
109+
return mig, nvml.SUCCESS
110+
}
111+
return nil, nvml.ERROR_INVALID_ARGUMENT
112+
}
102113
},
103114
expectedError: nil,
104115
expectedLength: 1,
@@ -139,5 +150,8 @@ func mockOverrides(server *dgxa100.Server) {
139150
(d.(*dgxa100.Device)).GetIndexFunc = func() (int, nvml.Return) {
140151
return i, nvml.SUCCESS
141152
}
153+
(d.(*dgxa100.Device)).GetUUIDFunc = func() (string, nvml.Return) {
154+
return d.(*dgxa100.Device).UUID, nvml.SUCCESS
155+
}
142156
}
143157
}

0 commit comments

Comments
 (0)