Skip to content

Commit 086955b

Browse files
committed
Add support for gated modifications in jit-cdi mode
This change ensures that NVIDIA_GDS=enabled, NVIDIA_MOFED=enabled, NVIDIA_GDRCOPY=enabled, and NVIDIA_NVSWITCH=enabled are honored by jit-cdi mode. Signed-off-by: Evan Lezar <[email protected]>
1 parent c3d0821 commit 086955b

File tree

2 files changed

+79
-23
lines changed

2 files changed

+79
-23
lines changed

internal/modifier/cdi.go

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
6262
return nil, fmt.Errorf("requesting a CDI device with vendor 'runtime.nvidia.com' is not supported when requesting other CDI devices")
6363
}
6464
if len(automaticDevices) > 0 {
65+
automaticDevices = append(automaticDevices, gatedDevices(image).DeviceRequests()...)
6566
automaticModifier, err := newAutomaticCDISpecModifier(logger, cfg, automaticDevices)
6667
if err == nil {
6768
return automaticModifier, nil
@@ -111,6 +112,29 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string {
111112
return devices
112113
}
113114

115+
type gatedDevices image.CUDA
116+
117+
// DeviceRequests returns a list of devices that are required for gated devices.
118+
func (g gatedDevices) DeviceRequests() []string {
119+
i := (image.CUDA)(g)
120+
121+
var devices []string
122+
if i.Getenv("NVIDIA_GDS") == "enabled" {
123+
devices = append(devices, "mode=gds")
124+
}
125+
if i.Getenv("NVIDIA_MOFED") == "enabled" {
126+
devices = append(devices, "mode=mofed")
127+
}
128+
if i.Getenv("NVIDIA_GDRCOPY") == "enabled" {
129+
devices = append(devices, "mode=gdrcopy")
130+
}
131+
if i.Getenv("NVIDIA_NVSWITCH") == "enabled" {
132+
devices = append(devices, "mode=nvswitch")
133+
}
134+
135+
return devices
136+
}
137+
114138
// filterAutomaticDevices searches for "automatic" device names in the input slice.
115139
// "Automatic" devices are a well-defined list of CDI device names which, when requested,
116140
// trigger the generation of a CDI spec at runtime. This removes the need to generate a
@@ -129,35 +153,48 @@ func filterAutomaticDevices(devices []string) []string {
129153
// newAutomaticCDISpecModifier generates in-memory CDI specs for the requested
// devices and returns a modifier built from those specs.
//
// NOTE(review): this is a diff hunk. Lines that follow an "NNN-" gutter marker
// belong to the removed implementation and lines that follow an "NNN+" marker
// belong to its replacement; the comments here describe the replacement.
//
// Entries in devices that start with "mode=" add an extra nvcdi mode; every
// other entry has automaticDevicePrefix trimmed and becomes an identifier for
// the "auto" mode. One CDI library, one spec and one CDI modifier are built per
// mode (always starting with "auto"), and the modifiers are returned together
// as an oci.SpecModifiers. The first failure aborts with an error that names
// the mode being processed.
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
130154
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
131155

132-
var identifiers []string
// Only the "auto" mode ever receives identifiers or a device class in this
// function; for any other mode the map lookups further down yield nil and ""
// respectively.
156+
perModeIdentifiers := make(map[string][]string)
157+
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
158+
modes := []string{"auto"}
133159
for _, device := range devices {
134-
identifiers = append(identifiers, strings.TrimPrefix(device, automaticDevicePrefix))
// A "mode=<name>" entry only selects a mode; it is not added to the list of
// device identifiers.
160+
if strings.HasPrefix(device, "mode=") {
161+
modes = append(modes, strings.TrimPrefix(device, "mode="))
162+
continue
163+
}
164+
perModeIdentifiers["auto"] = append(perModeIdentifiers["auto"], strings.TrimPrefix(device, automaticDevicePrefix))
135165
}
136166

137-
cdilib, err := nvcdi.New(
138-
nvcdi.WithLogger(logger),
139-
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
140-
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
141-
nvcdi.WithVendor(automaticDeviceVendor),
142-
nvcdi.WithClass(automaticDeviceClass),
143-
)
144-
if err != nil {
145-
return nil, fmt.Errorf("failed to construct CDI library: %w", err)
146-
}
// Build one modifier per mode, in the order in which the modes were collected.
167+
var modifiers oci.SpecModifiers
168+
for _, mode := range modes {
169+
cdilib, err := nvcdi.New(
170+
nvcdi.WithLogger(logger),
171+
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
172+
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
173+
nvcdi.WithVendor(automaticDeviceVendor),
174+
nvcdi.WithClass(perModeDeviceClass[mode]),
175+
nvcdi.WithMode(mode),
176+
)
177+
if err != nil {
178+
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)
179+
}
147180

148-
spec, err := cdilib.GetSpec(identifiers...)
149-
if err != nil {
150-
return nil, fmt.Errorf("failed to generate CDI spec: %w", err)
151-
}
152-
cdiDeviceRequestor, err := cdi.New(
153-
cdi.WithLogger(logger),
154-
cdi.WithSpec(spec.Raw()),
155-
)
156-
if err != nil {
157-
return nil, fmt.Errorf("failed to construct CDI modifier: %w", err)
181+
spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...)
182+
if err != nil {
183+
return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err)
184+
}
185+
186+
cdiDeviceRequestor, err := cdi.New(
187+
cdi.WithLogger(logger),
188+
cdi.WithSpec(spec.Raw()),
189+
)
190+
if err != nil {
191+
return nil, fmt.Errorf("failed to construct CDI modifier for mode %q: %w", mode, err)
192+
}
193+
194+
modifiers = append(modifiers, cdiDeviceRequestor)
158195
}
159196

160-
return cdiDeviceRequestor, nil
197+
return modifiers, nil
161198
}
162199

163200
type deduplicatedDeviceRequestor struct {

internal/oci/spec.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ type SpecModifier interface {
3131
Modify(*specs.Spec) error
3232
}
3333

34+
// SpecModifiers is a collection of OCI Spec modifiers that can be treated as a
35+
// single modifier.
36+
type SpecModifiers []SpecModifier
37+
38+
var _ SpecModifier = (SpecModifiers)(nil)
39+
3440
// Spec defines the operations to be performed on an OCI specification
3541
//
3642
//go:generate moq -rm -fmt=goimports -stub -out spec_mock.go . Spec
@@ -57,3 +63,16 @@ func NewSpec(logger logger.Interface, args []string) (Spec, error) {
5763

5864
return ociSpec, nil
5965
}
66+
67+
// Modify a spec based on a collection of modifiers.
68+
func (ms SpecModifiers) Modify(s *specs.Spec) error {
69+
for _, m := range ms {
70+
if m == nil {
71+
continue
72+
}
73+
if err := m.Modify(s); err != nil {
74+
return err
75+
}
76+
}
77+
return nil
78+
}

0 commit comments

Comments
 (0)