Skip to content

Commit 81cdd4e

Browse files
authored
Merge pull request #1194 from elezar/fix-cdi-spec-generation
[no-relnote] Fix CDI spec generation
2 parents ac49ea2 + 62794e9 commit 81cdd4e

File tree

4 files changed

+163
-83
lines changed

4 files changed

+163
-83
lines changed

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,48 @@ import (
3434
// single full GPU.
3535
type fullGPUDeviceSpecGenerator struct {
3636
*nvmllib
37-
id string
38-
index int
39-
device device.Device
37+
uuid string
38+
index int
4039
}
4140

4241
var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
4342

44-
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
45-
device, err := l.devicelib.NewDevice(nvmlDevice)
46-
if err != nil {
47-
return nil, err
43+
func (l *fullGPUDeviceSpecGenerator) GetUUID() (string, error) {
44+
return l.uuid, nil
45+
}
46+
47+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromDevice(index int, d device.Device) (*fullGPUDeviceSpecGenerator, error) {
48+
uuid, ret := d.GetUUID()
49+
if ret != nvml.SUCCESS {
50+
return nil, fmt.Errorf("failed to get device UUID: %v", ret)
51+
}
52+
e := &fullGPUDeviceSpecGenerator{
53+
nvmllib: l,
54+
uuid: uuid,
55+
index: index,
4856
}
4957

58+
return e, nil
59+
}
60+
61+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(uuid string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
5062
index, ret := nvmlDevice.GetIndex()
5163
if ret != nvml.SUCCESS {
5264
return nil, fmt.Errorf("failed to get device index: %v", ret)
5365
}
5466

5567
e := &fullGPUDeviceSpecGenerator{
5668
nvmllib: l,
57-
id: id,
69+
uuid: uuid,
5870
index: index,
59-
device: device,
6071
}
6172
return e, nil
6273
}
6374

6475
func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
6576
deviceEdits, err := l.getDeviceEdits()
6677
if err != nil {
67-
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", l.id, err)
78+
return nil, fmt.Errorf("failed to get CDI device edits: %w", err)
6879
}
6980

7081
names, err := l.getNames()
@@ -84,14 +95,23 @@ func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
8495
return deviceSpecs, nil
8596
}
8697

98+
func (l *fullGPUDeviceSpecGenerator) device() (device.Device, error) {
99+
return l.devicelib.NewDeviceByUUID(l.uuid)
100+
}
101+
87102
// getDeviceEdits returns the CDI edits for the full GPU identified by the generator's UUID.
88103
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
89-
device, err := l.newFullGPUDiscoverer(l.device)
104+
device, err := l.device()
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
deviceDiscoverer, err := l.newFullGPUDiscoverer(device)
90110
if err != nil {
91111
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
92112
}
93113

94-
editsForDevice, err := edits.FromDiscoverer(device)
114+
editsForDevice, err := edits.FromDiscoverer(deviceDiscoverer)
95115
if err != nil {
96116
return nil, fmt.Errorf("failed to create container edits for device: %v", err)
97117
}
@@ -100,7 +120,7 @@ func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, erro
100120
}
101121

102122
func (l *fullGPUDeviceSpecGenerator) getNames() ([]string, error) {
103-
return l.deviceNamers.GetDeviceNames(l.index, convert{l.device})
123+
return l.deviceNamers.GetDeviceNames(l.index, l)
104124
}
105125

106126
// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.

pkg/nvcdi/lib-nvml.go

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,22 @@ func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (DeviceSpecGenera
7272
identifiers = append(identifiers, device.Identifier(id))
7373
}
7474

75-
devices, err := l.getNVMLDevicesByID(identifiers...)
75+
uuids, err := l.normalizeDeviceIDs(identifiers...)
7676
if err != nil {
7777
return nil, err
7878
}
7979

8080
var DeviceSpecGenerators DeviceSpecGenerators
81-
for i, device := range devices {
82-
editor, err := l.newDeviceSpecGeneratorFromNVMLDevice(ids[i], device)
81+
for _, uuid := range uuids {
82+
device, ret := l.nvmllib.DeviceGetHandleByUUID(string(uuid))
83+
if ret != nvml.SUCCESS {
84+
return nil, fmt.Errorf("failed to get device handle from UUID: %v", ret)
85+
}
86+
generator, err := l.newDeviceSpecGeneratorFromNVMLDevice(string(uuid), device)
8387
if err != nil {
8488
return nil, err
8589
}
86-
DeviceSpecGenerators = append(DeviceSpecGenerators, editor)
90+
DeviceSpecGenerators = append(DeviceSpecGenerators, generator)
8791
}
8892

8993
return DeviceSpecGenerators, nil
@@ -92,7 +96,7 @@ func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (DeviceSpecGenera
9296
func (l *nvmllib) newDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
9397
isMig, ret := nvmlDevice.IsMigDeviceHandle()
9498
if ret != nvml.SUCCESS {
95-
return nil, ret
99+
return nil, fmt.Errorf("%v", ret)
96100
}
97101
if isMig {
98102
return l.newMIGDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice)
@@ -114,33 +118,23 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
114118
if isMigEnabled {
115119
return nil
116120
}
117-
e := &fullGPUDeviceSpecGenerator{
118-
nvmllib: l,
119-
id: fmt.Sprintf("%d", i),
120-
index: i,
121-
device: d,
121+
fullGPU, err := l.newFullGPUDeviceSpecGeneratorFromDevice(i, d)
122+
if err != nil {
123+
return err
122124
}
123-
124-
DeviceSpecGenerators = append(DeviceSpecGenerators, e)
125+
DeviceSpecGenerators = append(DeviceSpecGenerators, fullGPU)
125126
return nil
126127
})
127128
if err != nil {
128129
return nil, fmt.Errorf("failed to get full GPU device editors: %w", err)
129130
}
130131

131132
err = l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
132-
parentGenerator := &fullGPUDeviceSpecGenerator{
133-
nvmllib: l,
134-
index: i,
135-
id: fmt.Sprintf("%d:%d", i, j),
136-
device: d,
137-
}
138-
migGenerator := &migDeviceSpecGenerator{
139-
fullGPUDeviceSpecGenerator: parentGenerator,
140-
migIndex: j,
141-
migDevice: mig,
142-
}
143-
DeviceSpecGenerators = append(DeviceSpecGenerators, parentGenerator, migGenerator)
133+
migDevice, err := l.newMIGDeviceSpecGeneratorFromDevice(i, d, j, mig)
134+
if err != nil {
135+
return err
136+
}
137+
DeviceSpecGenerators = append(DeviceSpecGenerators, migDevice)
144138
return nil
145139
})
146140
if err != nil {
@@ -151,50 +145,68 @@ func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() (DeviceSpecGenerator, e
151145
}
152146

153147
// TODO: move this to go-nvlib?
154-
func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) {
155-
var devices []nvml.Device
148+
// normalizeDeviceIDs returns the UUIDs of the devices specified by the identifiers.
149+
func (l *nvmllib) normalizeDeviceIDs(identifiers ...device.Identifier) ([]device.Identifier, error) {
150+
var uuids []device.Identifier
156151
for _, id := range identifiers {
157-
dev, err := l.getNVMLDeviceByID(id)
158-
if err != nvml.SUCCESS {
159-
return nil, fmt.Errorf("failed to get NVML device handle for identifier %q: %w", id, err)
152+
uuid, err := l.normalizeDeviceID(id)
153+
if err != nil {
154+
return nil, err
160155
}
161-
devices = append(devices, dev)
156+
uuids = append(uuids, uuid)
162157
}
163-
return devices, nil
158+
return uuids, nil
164159
}
165160

166-
func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
161+
func (l *nvmllib) normalizeDeviceID(id device.Identifier) (device.Identifier, error) {
167162
var err error
168163

169164
if id.IsUUID() {
170-
return l.nvmllib.DeviceGetHandleByUUID(string(id))
165+
return id, nil
171166
}
172167

173168
if id.IsGpuIndex() {
174-
if idx, err := strconv.Atoi(string(id)); err == nil {
175-
return l.nvmllib.DeviceGetHandleByIndex(idx)
169+
idx, err := strconv.Atoi(string(id))
170+
if err != nil {
171+
return "", fmt.Errorf("failed to convert device index to an int: %w", err)
172+
}
173+
dev, ret := l.nvmllib.DeviceGetHandleByIndex(idx)
174+
if ret != nvml.SUCCESS {
175+
return "", fmt.Errorf("failed to get device handle from index: %v", ret)
176+
}
177+
uuid, ret := dev.GetUUID()
178+
if ret != nvml.SUCCESS {
179+
return "", fmt.Errorf("failed to get device UUID: %v", ret)
176180
}
177-
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
181+
return device.Identifier(uuid), nil
178182
}
179183

180184
if id.IsMigIndex() {
181185
var gpuIdx, migIdx int
182186
var parent nvml.Device
183187
split := strings.SplitN(string(id), ":", 2)
184188
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
185-
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
189+
return "", fmt.Errorf("failed to convert device index to an int: %w", err)
186190
}
187191
if migIdx, err = strconv.Atoi(split[1]); err != nil {
188-
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
192+
return "", fmt.Errorf("failed to convert device index to an int: %w", err)
189193
}
190194
parent, ret := l.nvmllib.DeviceGetHandleByIndex(gpuIdx)
191195
if ret != nvml.SUCCESS {
192-
return nil, fmt.Errorf("failed to get parent device handle: %v", ret)
196+
return "", fmt.Errorf("failed to get parent device handle: %v", ret)
197+
}
198+
mig, ret := parent.GetMigDeviceHandleByIndex(migIdx)
199+
if ret != nvml.SUCCESS {
200+
return "", fmt.Errorf("failed to get MIG handle by index: %v", ret)
201+
}
202+
uuid, ret := mig.GetUUID()
203+
if ret != nvml.SUCCESS {
204+
return "", fmt.Errorf("failed to get MIG UUID: %v", ret)
193205
}
194-
return parent.GetMigDeviceHandleByIndex(migIdx)
206+
return device.Identifier(uuid), nil
195207
}
196208

197-
return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id)
209+
return "", fmt.Errorf("identifier is not a valid UUID or index: %q", id)
198210
}
199211

200212
func (l *nvmllib) init() error {

pkg/nvcdi/lib-nvml_test.go

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestNvmllibGetDeviceSpecGeneratorsForIDs(t *testing.T) {
7171
if s == "GPU-12345678-1234-1234-1234-123456789abc" {
7272
return server.Devices[3], nvml.SUCCESS
7373
}
74-
return nil, nvml.ERROR_INVALID_ARGUMENT
74+
return server.Devices[0], nvml.SUCCESS
7575
}
7676
},
7777
expectedError: nil,
@@ -81,24 +81,35 @@ func TestNvmllibGetDeviceSpecGeneratorsForIDs(t *testing.T) {
8181
name: "MIG device index",
8282
ids: []string{"0:0"},
8383
setupMock: func(server *dgxa100.Server) {
84+
mig := &mocknvml.Device{
85+
IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
86+
return true, nvml.SUCCESS
87+
},
88+
GetDeviceHandleFromMigDeviceHandleFunc: func() (nvml.Device, nvml.Return) {
89+
return server.Devices[0], nvml.SUCCESS
90+
},
91+
GetIndexFunc: func() (int, nvml.Return) {
92+
return 0, nvml.SUCCESS
93+
},
94+
GetUUIDFunc: func() (string, nvml.Return) {
95+
return "MIG-foo", nvml.SUCCESS
96+
},
97+
}
98+
8499
server.Devices[0].(*dgxa100.Device).GetMigDeviceHandleByIndexFunc = func(n int) (nvml.Device, nvml.Return) {
85100
if n != 0 {
86101
return nil, nvml.ERROR_INVALID_ARGUMENT
87102
}
88103

89-
mig := &mocknvml.Device{
90-
IsMigDeviceHandleFunc: func() (bool, nvml.Return) {
91-
return true, nvml.SUCCESS
92-
},
93-
GetDeviceHandleFromMigDeviceHandleFunc: func() (nvml.Device, nvml.Return) {
94-
return server.Devices[0], nvml.SUCCESS
95-
},
96-
GetIndexFunc: func() (int, nvml.Return) {
97-
return 0, nvml.SUCCESS
98-
},
99-
}
100104
return mig, nvml.SUCCESS
101105
}
106+
107+
server.DeviceGetHandleByUUIDFunc = func(s string) (nvml.Device, nvml.Return) {
108+
if s == "MIG-foo" {
109+
return mig, nvml.SUCCESS
110+
}
111+
return nil, nvml.ERROR_INVALID_ARGUMENT
112+
}
102113
},
103114
expectedError: nil,
104115
expectedLength: 1,
@@ -139,5 +150,8 @@ func mockOverrides(server *dgxa100.Server) {
139150
(d.(*dgxa100.Device)).GetIndexFunc = func() (int, nvml.Return) {
140151
return i, nvml.SUCCESS
141152
}
153+
(d.(*dgxa100.Device)).GetUUIDFunc = func() (string, nvml.Return) {
154+
return d.(*dgxa100.Device).UUID, nvml.SUCCESS
155+
}
142156
}
143157
}

0 commit comments

Comments
 (0)