Skip to content

Commit 6703673

Browse files
authored
Merge pull request #1230 from elezar/fix-gdrcopy-in-jit-cdi
Add support for gated modifications jit-cdi mode
2 parents 4507575 + 3fcc351 commit 6703673

File tree

6 files changed

+118
-103
lines changed

6 files changed

+118
-103
lines changed

internal/modifier/cdi.go

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
6262
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")
6363
}
6464
if len(automaticDevices) > 0 {
65+
automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...)
6566
automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices)
6667
if err == nil {
6768
return automaticModifier, nil
@@ -111,6 +112,29 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string {
111112
return devices
112113
}
113114

115+
type gatedDevices image.CUDA
116+
117+
// DeviceRequests returns a list of devices that are required for gated devices.
118+
func (g gatedDevices) DeviceRequests() []string {
119+
i := (image.CUDA)(g)
120+
121+
var devices []string
122+
if i.Getenv("NVIDIA_GDS") == "enabled" {
123+
devices = append(devices, "mode=gds")
124+
}
125+
if i.Getenv("NVIDIA_MOFED") == "enabled" {
126+
devices = append(devices, "mode=mofed")
127+
}
128+
if i.Getenv("NVIDIA_GDRCOPY") == "enabled" {
129+
devices = append(devices, "mode=gdrcopy")
130+
}
131+
if i.Getenv("NVIDIA_NVSWITCH") == "enabled" {
132+
devices = append(devices, "mode=nvswitch")
133+
}
134+
135+
return devices
136+
}
137+
114138
// filterAutomaticDevices searches for "automatic" device names in the input slice.
115139
// "Automatic" devices are a well-defined list of CDI device names which, when requested,
116140
// trigger the generation of a CDI spec at runtime. This removes the need to generate a
@@ -129,35 +153,48 @@ func filterAutomaticDevices(devices []string) []string {
129153
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
130154
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
131155

132-
var identifiers []string
156+
perModeIdentifiers := make(map[string][]string)
157+
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
158+
modes := []string{"auto"}
133159
for _, device := range devices {
134-
identifiers = append(identifiers, strings.TrimPrefix(device, automaticDevicePrefix))
160+
if strings.HasPrefix(device, "mode=") {
161+
modes = append(modes, strings.TrimPrefix(device, "mode="))
162+
continue
163+
}
164+
perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix))
135165
}
136166

137-
cdilib, err := nvcdi.New(
138-
nvcdi.WithLogger(logger),
139-
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
140-
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
141-
nvcdi.WithVendor(automaticDeviceVendor),
142-
nvcdi.WithClass(automaticDeviceClass),
143-
)
144-
if err != nil {
145-
return nil, fmt.Errorf("failed to construct CDI library: %w", err)
146-
}
167+
var modifiers oci.SpecModifiers
168+
for _, mode := range modes {
169+
cdilib, err := nvcdi.New(
170+
nvcdi.WithLogger(logger),
171+
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
172+
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
173+
nvcdi.WithVendor(automaticDeviceVendor),
174+
nvcdi.WithClass(perModeDeviceClass[mode]),
175+
nvcdi.WithMode(mode),
176+
)
177+
if err != nil {
178+
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)
179+
}
147180

148-
spec, err := cdilib.GetSpec(identifiers...)
149-
if err != nil {
150-
return nil, fmt.Errorf("failed to generate CDI spec: %w", err)
151-
}
152-
cdiDeviceRequestor, err := cdi.New(
153-
cdi.WithLogger(logger),
154-
cdi.WithSpec(spec.Raw()),
155-
)
156-
if err != nil {
157-
return nil, fmt.Errorf("failed to construct CDI modifier: %w", err)
181+
spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...)
182+
if err != nil {
183+
return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err)
184+
}
185+
186+
cdiDeviceRequestor, err := cdi.New(
187+
cdi.WithLogger(logger),
188+
cdi.WithSpec(spec.Raw()),
189+
)
190+
if err != nil {
191+
return nil, fmt.Errorf("failed to construct CDI modifier for mode %q: %w", mode, err)
192+
}
193+
194+
modifiers = append(modifiers, cdiDeviceRequestor)
158195
}
159196

160-
return cdiDeviceRequestor, nil
197+
return modifiers, nil
161198
}
162199

163200
type deduplicatedDeviceRequestor struct {

internal/oci/spec.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ type SpecModifier interface {
3131
Modify(*specs.Spec) error
3232
}
3333

34+
// SpecModifiers is a collection of OCI Spec modifiers that can be treated as a
35+
// single modifier.
36+
type SpecModifiers []SpecModifier
37+
38+
var _ SpecModifier = (SpecModifiers)(nil)
39+
3440
// Spec defines the operations to be performed on an OCI specification
3541
//
3642
//go:generate moq -rm -fmt=goimports -stub -out spec_mock.go . Spec
@@ -57,3 +63,16 @@ func NewSpec(logger logger.Interface, args []string) (Spec, error) {
5763

5864
return ociSpec, nil
5965
}
66+
67+
// Modify a spec based on a collection of modifiers.
68+
func (ms SpecModifiers) Modify(s *specs.Spec) error {
69+
for _, m := range ms {
70+
if m == nil {
71+
continue
72+
}
73+
if err := m.Modify(s); err != nil {
74+
return err
75+
}
76+
}
77+
return nil
78+
}

pkg/nvcdi/mofed.go renamed to pkg/nvcdi/gated.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,23 @@ import (
2626
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2727
)
2828

29-
type mofedlib nvcdilib
29+
type gatedlib nvcdilib
3030

31-
var _ deviceSpecGeneratorFactory = (*mofedlib)(nil)
31+
var _ deviceSpecGeneratorFactory = (*gatedlib)(nil)
3232

33-
func (l *mofedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
33+
func (l *gatedlib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
3434
return l, nil
3535
}
3636

3737
// GetDeviceSpecs returns the CDI device specs for a single all device.
38-
func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) {
39-
discoverer, err := discover.NewMOFEDDiscoverer(l.logger, l.driverRoot)
38+
func (l *gatedlib) GetDeviceSpecs() ([]specs.Device, error) {
39+
discoverer, err := l.getModeDiscoverer()
4040
if err != nil {
41-
return nil, fmt.Errorf("failed to create MOFED discoverer: %v", err)
41+
return nil, fmt.Errorf("failed to create discoverer for mode %q: %w", l.mode, err)
4242
}
4343
edits, err := edits.FromDiscoverer(discoverer)
4444
if err != nil {
45-
return nil, fmt.Errorf("failed to create container edits for MOFED devices: %v", err)
45+
return nil, fmt.Errorf("failed to create container edits: %w", err)
4646
}
4747

4848
deviceSpec := specs.Device{
@@ -53,7 +53,22 @@ func (l *mofedlib) GetDeviceSpecs() ([]specs.Device, error) {
5353
return []specs.Device{deviceSpec}, nil
5454
}
5555

56+
func (l *gatedlib) getModeDiscoverer() (discover.Discover, error) {
57+
switch l.mode {
58+
case ModeGdrcopy:
59+
return discover.NewGDRCopyDiscoverer(l.logger, l.devRoot)
60+
case ModeGds:
61+
return discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot)
62+
case ModeMofed:
63+
return discover.NewMOFEDDiscoverer(l.logger, l.driverRoot)
64+
case ModeNvswitch:
65+
return discover.NewNvSwitchDiscoverer(l.logger, l.devRoot)
66+
default:
67+
return nil, fmt.Errorf("unrecognized mode")
68+
}
69+
}
70+
5671
// GetCommonEdits generates a CDI specification that can be used for ANY devices
57-
func (l *mofedlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
72+
func (l *gatedlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
5873
return edits.FromDiscoverer(discover.None{})
5974
}

pkg/nvcdi/gds.go

Lines changed: 0 additions & 59 deletions
This file was deleted.

pkg/nvcdi/lib.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,11 @@ func New(opts ...Option) (Interface, error) {
129129
factory = (*nvmllib)(l)
130130
case ModeWsl:
131131
factory = (*wsllib)(l)
132-
case ModeGds:
132+
case ModeGdrcopy, ModeGds, ModeMofed:
133133
if l.class == "" {
134-
l.class = "gds"
134+
l.class = string(l.mode)
135135
}
136-
factory = (*gdslib)(l)
137-
case ModeMofed:
138-
if l.class == "" {
139-
l.class = "mofed"
140-
}
141-
factory = (*mofedlib)(l)
136+
factory = (*gatedlib)(l)
142137
case ModeImex:
143138
if l.class == "" {
144139
l.class = classImexChannel

pkg/nvcdi/mode.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@ const (
3333
ModeWsl = Mode("wsl")
3434
// ModeManagement configures the CDI spec generator to generate a management spec.
3535
ModeManagement = Mode("management")
36+
// ModeGdrcopy configures the CDI spec generator to generate a GDR Copy spec.
37+
ModeGdrcopy = Mode("gdrcopy")
3638
// ModeGds configures the CDI spec generator to generate a GDS spec.
3739
ModeGds = Mode("gds")
3840
// ModeMofed configures the CDI spec generator to generate a MOFED spec.
3941
ModeMofed = Mode("mofed")
4042
// ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV
4143
// mountspec files.
4244
ModeCSV = Mode("csv")
43-
// ModeImex configures the CDI spec generated to generate a spec for the available IMEX channels.
45+
// ModeImex configures the CDI spec generator to generate a spec for the available IMEX channels.
4446
ModeImex = Mode("imex")
47+
// ModeNvswitch configures the CDI spec generator to generate a spec for the available nvswitch devices.
48+
ModeNvswitch = Mode("nvswitch")
4549
)
4650

4751
type modeConstraint interface {
@@ -60,12 +64,15 @@ func getModes() modes {
6064
validModesOnce.Do(func() {
6165
all := []Mode{
6266
ModeAuto,
63-
ModeNvml,
64-
ModeWsl,
65-
ModeManagement,
67+
ModeCSV,
68+
ModeGdrcopy,
6669
ModeGds,
70+
ModeImex,
71+
ModeManagement,
6772
ModeMofed,
68-
ModeCSV,
73+
ModeNvml,
74+
ModeNvswitch,
75+
ModeWsl,
6976
}
7077
lookup := make(map[Mode]bool)
7178

@@ -103,6 +110,7 @@ func (l *nvcdilib) resolveMode() (rmode Mode) {
103110
}
104111
defer func() {
105112
l.logger.Infof("Auto-detected mode as '%v'", rmode)
113+
l.mode = rmode
106114
}()
107115

108116
platform := l.infolib.ResolvePlatform()

0 commit comments

Comments
 (0)