Skip to content

Commit 57ef289

Browse files
committed
Handle multiple GPUs in CDI spec generation from CSV
This change allows CDI specs to be generated for multiple devices when using CSV mode. This can be used in cases where a Tegra-based system consists of an iGPU and a dGPU. This behavior can be opted out of using the disable-multiple-csv-devices feature flag. The flag can be specified either by adding the --feature-flags=disable-multiple-csv-devices command line option to the nvidia-ctk cdi generate command or, for automatic CDI spec generation, by adding NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS=disable-multiple-csv-devices to the /etc/nvidia-container-toolkit/nvidia-cdi-refresh.env file. Signed-off-by: Evan Lezar <[email protected]>
1 parent 37a16a7 commit 57ef289

File tree

4 files changed

+268
-12
lines changed

4 files changed

+268
-12
lines changed

internal/platform-support/tegra/mount_specs.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717

1818
package tegra
1919

20-
import "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
20+
import (
21+
"path/filepath"
22+
"strconv"
23+
"strings"
24+
25+
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
26+
)
2127

2228
// A MountSpecPathsByTyper provides a function to return mount specs paths by
2329
// mount type.
@@ -96,6 +102,53 @@ func (m filterMountSpecs) MountSpecPathsByType() MountSpecPathsByType {
96102
return ms
97103
}
98104

105+
// stripDeviceNodes wraps a MountSpecPathsByTyper and removes regular
// `/dev/nvidia[0-9]+` device nodes from the mount specs that it returns.
type stripDeviceNodes struct {
106+
// from is the wrapped source of mount specs.
from MountSpecPathsByTyper
107+
}
108+
109+
// WithoutRegularDeviceNodes creates a MountSpecPathsByTyper which removes
110+
// regular `/dev/nvidia[0-9]+` device nodes from the source.
111+
func WithoutRegularDeviceNodes(from MountSpecPathsByTyper) MountSpecPathsByTyper {
112+
return &stripDeviceNodes{from}
113+
}
114+
115+
// MountSpecPathsByType returns the source mount specs with regular nvidia
116+
// device nodes removed from the source.
117+
func (d *stripDeviceNodes) MountSpecPathsByType() MountSpecPathsByType {
118+
ms := d.from.MountSpecPathsByType()
119+
if len(ms) == 0 {
120+
return ms
121+
}
122+
123+
filtered := d.Apply(ms[csv.MountSpecDev]...)
124+
ms[csv.MountSpecDev] = filtered
125+
126+
return ms
127+
}
128+
129+
func (d *stripDeviceNodes) Apply(input ...string) []string {
130+
var filtered []string
131+
for _, name := range input {
132+
if d.Match(name) {
133+
continue
134+
}
135+
filtered = append(filtered, name)
136+
}
137+
return filtered
138+
}
139+
140+
// Match returns true if name is a REGULAR NVIDIA GPU device node.
141+
func (d *stripDeviceNodes) Match(name string) bool {
142+
pattern := "/dev/nvidia*"
143+
if match, _ := filepath.Match(pattern, name); !match {
144+
return false
145+
}
146+
suffix := strings.TrimPrefix(name, "/dev/nvidia")
147+
// Check whether path has the form /dev/nvidia%d
148+
_, err := strconv.Atoi(suffix)
149+
return err == nil
150+
}
151+
99152
// DeviceNodes creates a set of MountSpecPaths for the specified device nodes.
100153
// These have the MountSpecDev type.
101154
func DeviceNodes(dn ...string) MountSpecPathsByTyper {

pkg/nvcdi/api.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,8 @@ const (
8888
// FeatureEnableCoherentAnnotations enables the addition of annotations
8989
// coherent or non-coherent devices.
9090
FeatureEnableCoherentAnnotations = FeatureFlag("enable-coherent-annotations")
91+
92+
// FeatureDisableMultipleCSVDevices disables the handling of multiple devices
93+
// in CSV mode.
94+
FeatureDisableMultipleCSVDevices = FeatureFlag("disable-multiple-csv-devices")
9195
)

pkg/nvcdi/lib-csv.go

Lines changed: 210 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,49 @@ package nvcdi
1818

1919
import (
2020
"fmt"
21+
"slices"
22+
"strconv"
23+
"strings"
2124

2225
"tags.cncf.io/container-device-interface/pkg/cdi"
2326
"tags.cncf.io/container-device-interface/specs-go"
2427

28+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
29+
"github.com/NVIDIA/go-nvml/pkg/nvml"
30+
"github.com/google/uuid"
31+
2532
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2633
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2734
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
2835
)
2936

3037
type csvlib nvcdilib
3138

39+
type mixedcsvlib nvcdilib
40+
3241
var _ deviceSpecGeneratorFactory = (*csvlib)(nil)
3342

43+
// DeviceSpecGenerators creates a set of generators for the specified set of
44+
// devices.
45+
// If NVML is not available or the disable-multiple-csv-devices feature flag is
46+
// enabled, a single device is assumed.
3447
func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
48+
if l.featureFlags[FeatureDisableMultipleCSVDevices] {
49+
return l.purecsvDeviceSpecGenerators(ids...)
50+
}
51+
hasNVML, _ := l.infolib.HasNvml()
52+
if !hasNVML {
53+
return l.purecsvDeviceSpecGenerators(ids...)
54+
}
55+
mixed, err := l.mixedDeviceSpecGenerators(ids...)
56+
if err != nil {
57+
l.logger.Warningf("Failed to create mixed CSV spec generator; falling back to pure CSV implementation: %v", err)
58+
return l.purecsvDeviceSpecGenerators(ids...)
59+
}
60+
return mixed, nil
61+
}
62+
63+
func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
3564
for _, id := range ids {
3665
switch id {
3766
case "all":
@@ -40,12 +69,41 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
4069
return nil, fmt.Errorf("unsupported device id: %v", id)
4170
}
4271
}
72+
g := &csvDeviceGenerator{
73+
csvlib: l,
74+
index: 0,
75+
uuid: "",
76+
}
77+
return g, nil
78+
}
79+
80+
func (l *csvlib) mixedDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
81+
return (*mixedcsvlib)(l).DeviceSpecGenerators(ids...)
82+
}
4383

44-
return l, nil
84+
// A csvDeviceGenerator generates CDI specs for a device based on a set of
85+
// platform-specific CSV files.
86+
type csvDeviceGenerator struct {
87+
// csvlib is embedded to give access to the shared library configuration
// (for example the logger and the list of CSV files).
*csvlib
88+
// index is the device index passed to the device namers.
index int
89+
// uuid is the value returned by GetUUID.
uuid string
90+
// onlyDeviceNodes, when non-empty, replaces the regular device nodes that
// were read from the CSV files.
onlyDeviceNodes []string
91+
// additionalDeviceNodes are extra device nodes merged into the mount specs.
additionalDeviceNodes []string
92+
}
93+
94+
func (l *csvDeviceGenerator) GetUUID() (string, error) {
95+
return l.uuid, nil
4596
}
4697

4798
// GetDeviceSpecs returns the CDI device specs for a single device.
48-
func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
99+
func (l *csvDeviceGenerator) GetDeviceSpecs() ([]specs.Device, error) {
100+
mountSpecs := tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...)
101+
if len(l.onlyDeviceNodes) > 0 {
102+
mountSpecs = tegra.Merge(
103+
tegra.WithoutRegularDeviceNodes(mountSpecs),
104+
tegra.DeviceNodes(l.onlyDeviceNodes...),
105+
)
106+
}
49107
d, err := tegra.New(
50108
tegra.WithLogger(l.logger),
51109
tegra.WithDriverRoot(l.driverRoot),
@@ -55,8 +113,13 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
55113
tegra.WithLibrarySearchPaths(l.librarySearchPaths...),
56114
tegra.WithMountSpecsByPath(
57115
tegra.Filter(
58-
tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...),
59-
tegra.Symlinks(l.csvIgnorePatterns...),
116+
tegra.Merge(
117+
mountSpecs,
118+
tegra.DeviceNodes(l.additionalDeviceNodes...),
119+
),
120+
tegra.Merge(
121+
tegra.Symlinks(l.csvIgnorePatterns...),
122+
),
60123
),
61124
),
62125
)
@@ -68,7 +131,7 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
68131
return nil, fmt.Errorf("failed to create container edits for CSV files: %v", err)
69132
}
70133

71-
names, err := l.deviceNamers.GetDeviceNames(0, uuidIgnored{})
134+
names, err := l.deviceNamers.GetDeviceNames(l.index, l)
72135
if err != nil {
73136
return nil, fmt.Errorf("failed to get device name: %v", err)
74137
}
@@ -88,3 +151,145 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
88151
// GetCommonEdits returns the container edits produced from the no-op
// discoverer (discover.None).
func (l *csvlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
	noop := discover.None{}
	return edits.FromDiscoverer(noop)
}
154+
155+
func (l *mixedcsvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
156+
asNvmlLib := (*nvmllib)(l)
157+
err := asNvmlLib.init()
158+
if err != nil {
159+
return nil, fmt.Errorf("failed to initialize nvml: %w", err)
160+
}
161+
defer asNvmlLib.tryShutdown()
162+
163+
if slices.Contains(ids, "all") {
164+
ids, err = l.getAllDeviceIndices()
165+
if err != nil {
166+
return nil, fmt.Errorf("failed to get device indices: %w", err)
167+
}
168+
}
169+
170+
var DeviceSpecGenerators DeviceSpecGenerators
171+
for _, id := range ids {
172+
generator, err := l.deviceSpecGeneratorForId(device.Identifier(id))
173+
if err != nil {
174+
return nil, fmt.Errorf("failed to create device spec generator for device %q: %w", id, err)
175+
}
176+
DeviceSpecGenerators = append(DeviceSpecGenerators, generator)
177+
}
178+
179+
return DeviceSpecGenerators, nil
180+
}
181+
182+
func (l *mixedcsvlib) getAllDeviceIndices() ([]string, error) {
183+
numDevices, ret := l.nvmllib.DeviceGetCount()
184+
if ret != nvml.SUCCESS {
185+
return nil, fmt.Errorf("faled to get device count: %v", ret)
186+
}
187+
188+
var allIndices []string
189+
for index := range numDevices {
190+
allIndices = append(allIndices, fmt.Sprintf("%d", index))
191+
}
192+
return allIndices, nil
193+
}
194+
195+
func (l *mixedcsvlib) deviceSpecGeneratorForId(id device.Identifier) (DeviceSpecGenerator, error) {
196+
switch {
197+
case id.IsGpuUUID(), isIntegratedGPUID(id):
198+
uuid := string(id)
199+
device, ret := l.nvmllib.DeviceGetHandleByUUID(uuid)
200+
if ret != nvml.SUCCESS {
201+
return nil, fmt.Errorf("failed to get device handle from UUID %q: %v", uuid, ret)
202+
}
203+
index, ret := device.GetIndex()
204+
if ret != nvml.SUCCESS {
205+
return nil, fmt.Errorf("failed to get device index: %v", ret)
206+
}
207+
return l.csvDeviceSpecGenerator(index, uuid, device)
208+
case id.IsGpuIndex():
209+
index, err := strconv.Atoi(string(id))
210+
if err != nil {
211+
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
212+
}
213+
device, ret := l.nvmllib.DeviceGetHandleByIndex(index)
214+
if ret != nvml.SUCCESS {
215+
return nil, fmt.Errorf("failed to get device handle from index: %v", ret)
216+
}
217+
uuid, ret := device.GetUUID()
218+
if ret != nvml.SUCCESS {
219+
return nil, fmt.Errorf("failed to get UUID: %v", ret)
220+
}
221+
return l.csvDeviceSpecGenerator(index, uuid, device)
222+
case id.IsMigUUID():
223+
fallthrough
224+
case id.IsMigIndex():
225+
return nil, fmt.Errorf("generating a CDI spec for MIG id %q is not supported in CSV mode", id)
226+
}
227+
return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id)
228+
}
229+
230+
func (l *mixedcsvlib) csvDeviceSpecGenerator(index int, uuid string, device nvml.Device) (DeviceSpecGenerator, error) {
231+
var additionalDeviceNodes []string
232+
isIntegrated, err := isIntegratedGPU(device)
233+
if err != nil {
234+
return nil, fmt.Errorf("is-integrated check failed for device (index=%v,uuid=%v)", index, uuid)
235+
}
236+
if !isIntegrated {
237+
additionalDeviceNodes = []string{
238+
"/dev/nvidia-uvm",
239+
"/dev/nvidia-uvm-tools",
240+
}
241+
}
242+
g := &csvDeviceGenerator{
243+
csvlib: (*csvlib)(l),
244+
index: index,
245+
uuid: uuid,
246+
onlyDeviceNodes: []string{fmt.Sprintf("/dev/nvidia%d", index)},
247+
additionalDeviceNodes: additionalDeviceNodes,
248+
}
249+
return g, nil
250+
}
251+
252+
func isIntegratedGPUID(id device.Identifier) bool {
253+
_, err := uuid.Parse(string(id))
254+
return err == nil
255+
}
256+
257+
// isIntegratedGPU checks whether the specified device is an integrated GPU.
258+
// As a proxy we check the PCI Bus if for thes
259+
// TODO: This should be replaced by an explicit NVML call once available.
260+
func isIntegratedGPU(d nvml.Device) (bool, error) {
261+
pciInfo, ret := d.GetPciInfo()
262+
if ret == nvml.ERROR_NOT_SUPPORTED {
263+
name, ret := d.GetName()
264+
if ret != nvml.SUCCESS {
265+
return false, fmt.Errorf("failed to get device name: %v", ret)
266+
}
267+
return isIntegratedGPUName(name), nil
268+
}
269+
if ret != nvml.SUCCESS {
270+
return false, fmt.Errorf("failed to get PCI info: %v", ret)
271+
}
272+
273+
if pciInfo.Domain != 0 {
274+
return false, nil
275+
}
276+
if pciInfo.Bus != 1 {
277+
return false, nil
278+
}
279+
return pciInfo.Device == 0, nil
280+
}
281+
282+
// isIntegratedGPUName reports whether name contains one of the substrings
// associated with a known iGPU ("(nvgpu)" or "NVIDIA Thor"). The comparison
// is case-sensitive.
//
// TODO: Consider making go-nvlib/pkg/nvlib/info/isIntegratedGPUName public
// instead.
func isIntegratedGPUName(name string) bool {
	for _, marker := range []string{"(nvgpu)", "NVIDIA Thor"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}

pkg/nvcdi/namer.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,6 @@ type convert struct {
105105
nvmlUUIDer
106106
}
107107

108-
type uuidIgnored struct{}
109-
110-
func (m uuidIgnored) GetUUID() (string, error) {
111-
return "", nil
112-
}
113-
114108
type uuidUnsupported struct{}
115109

116110
func (m convert) GetUUID() (string, error) {

0 commit comments

Comments
 (0)