Skip to content

Commit df4c87b

Browse files
authored
Merge pull request #838 from cdesiniotis/enable-cdi-toolkit-container
Enable CDI in the container runtime if enabled in the toolkit
2 parents d6c3129 + d8cd543 commit df4c87b

File tree

14 files changed

+604
-66
lines changed

14 files changed

+604
-66
lines changed

cmd/nvidia-ctk-installer/container/container.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ const (
3636

3737
// Options defines the shared options for the CLIs to configure containers runtimes.
3838
type Options struct {
39-
Config string
40-
Socket string
39+
Config string
40+
Socket string
41+
// EnabledCDI indicates whether CDI should be enabled.
42+
EnableCDI bool
4143
RuntimeName string
4244
RuntimeDir string
4345
SetAsDefault bool
@@ -111,6 +113,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
111113
}
112114
}
113115

116+
if o.EnableCDI {
117+
cfg.EnableCDI()
118+
}
119+
114120
return nil
115121
}
116122

cmd/nvidia-ctk-installer/container/runtime/containerd/config_v1_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,51 @@ func TestUpdateV1ConfigWithRuncPresent(t *testing.T) {
410410
}
411411
}
412412

413+
func TestUpdateV1EnableCDI(t *testing.T) {
414+
logger, _ := testlog.NewNullLogger()
415+
const runtimeDir = "/test/runtime/dir"
416+
417+
testCases := []struct {
418+
enableCDI bool
419+
expectedEnableCDIValue interface{}
420+
}{
421+
{},
422+
{
423+
enableCDI: false,
424+
expectedEnableCDIValue: nil,
425+
},
426+
{
427+
enableCDI: true,
428+
expectedEnableCDIValue: true,
429+
},
430+
}
431+
432+
for _, tc := range testCases {
433+
t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) {
434+
o := &container.Options{
435+
EnableCDI: tc.enableCDI,
436+
RuntimeName: "nvidia",
437+
RuntimeDir: runtimeDir,
438+
}
439+
440+
cfg, err := toml.Empty.Load()
441+
require.NoError(t, err)
442+
443+
v1 := &containerd.ConfigV1{
444+
Logger: logger,
445+
Tree: cfg,
446+
RuntimeType: runtimeType,
447+
}
448+
449+
err = o.UpdateConfig(v1)
450+
require.NoError(t, err)
451+
452+
enableCDIValue := v1.GetPath([]string{"plugins", "cri", "containerd", "enable_cdi"})
453+
require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue)
454+
})
455+
}
456+
}
457+
413458
func TestRevertV1Config(t *testing.T) {
414459
logger, _ := testlog.NewNullLogger()
415460
testCases := []struct {

cmd/nvidia-ctk-installer/container/runtime/containerd/config_v2_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,53 @@ func TestUpdateV2ConfigWithRuncPresent(t *testing.T) {
366366
}
367367
}
368368

369+
func TestUpdateV2ConfigEnableCDI(t *testing.T) {
370+
logger, _ := testlog.NewNullLogger()
371+
const runtimeDir = "/test/runtime/dir"
372+
373+
testCases := []struct {
374+
enableCDI bool
375+
expectedEnableCDIValue interface{}
376+
}{
377+
{},
378+
{
379+
enableCDI: false,
380+
expectedEnableCDIValue: nil,
381+
},
382+
{
383+
enableCDI: true,
384+
expectedEnableCDIValue: true,
385+
},
386+
}
387+
388+
for _, tc := range testCases {
389+
t.Run(fmt.Sprintf("%v", tc.enableCDI), func(t *testing.T) {
390+
o := &container.Options{
391+
EnableCDI: tc.enableCDI,
392+
RuntimeName: "nvidia",
393+
RuntimeDir: runtimeDir,
394+
SetAsDefault: false,
395+
}
396+
397+
cfg, err := toml.LoadMap(map[string]interface{}{})
398+
require.NoError(t, err)
399+
400+
v2 := &containerd.Config{
401+
Logger: logger,
402+
Tree: cfg,
403+
RuntimeType: runtimeType,
404+
CRIRuntimePluginName: "io.containerd.grpc.v1.cri",
405+
}
406+
407+
err = o.UpdateConfig(v2)
408+
require.NoError(t, err)
409+
410+
enableCDIValue := cfg.GetPath([]string{"plugins", "io.containerd.grpc.v1.cri", "enable_cdi"})
411+
require.EqualValues(t, tc.expectedEnableCDIValue, enableCDIValue)
412+
})
413+
}
414+
}
415+
369416
func TestRevertV2Config(t *testing.T) {
370417
logger, _ := testlog.NewNullLogger()
371418

cmd/nvidia-ctk-installer/container/runtime/runtime.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/containerd"
2626
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/crio"
2727
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/docker"
28+
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/toolkit"
2829
)
2930

3031
const (
@@ -66,6 +67,12 @@ func Flags(opts *Options) []cli.Flag {
6667
Destination: &opts.RestartMode,
6768
EnvVars: []string{"RUNTIME_RESTART_MODE"},
6869
},
70+
&cli.BoolFlag{
71+
Name: "enable-cdi-in-runtime",
72+
Usage: "Enable CDI in the configured runt ime",
73+
Destination: &opts.EnableCDI,
74+
EnvVars: []string{"RUNTIME_ENABLE_CDI"},
75+
},
6976
&cli.StringFlag{
7077
Name: "host-root",
7178
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",
@@ -98,10 +105,14 @@ func Flags(opts *Options) []cli.Flag {
98105
}
99106

100107
// ValidateOptions checks whether the specified options are valid
101-
func ValidateOptions(opts *Options, runtime string, toolkitRoot string) error {
108+
func ValidateOptions(c *cli.Context, opts *Options, runtime string, toolkitRoot string, to *toolkit.Options) error {
102109
// We set this option here to ensure that it is available in future calls.
103110
opts.RuntimeDir = toolkitRoot
104111

112+
if !c.IsSet("enable-cdi-in-runtime") {
113+
opts.EnableCDI = to.CDI.Enabled
114+
}
115+
105116
// Apply the runtime-specific config changes.
106117
switch runtime {
107118
case containerd.Name:

cmd/nvidia-ctk-installer/container/toolkit/toolkit.go

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ const (
4444
configFilename = "config.toml"
4545
)
4646

47+
type cdiOptions struct {
48+
Enabled bool
49+
outputDir string
50+
kind string
51+
vendor string
52+
class string
53+
}
54+
4755
type Options struct {
4856
DriverRoot string
4957
DevRoot string
@@ -63,11 +71,8 @@ type Options struct {
6371

6472
ContainerCLIDebug string
6573

66-
cdiEnabled bool
67-
cdiOutputDir string
68-
cdiKind string
69-
cdiVendor string
70-
cdiClass string
74+
// CDI stores the CDI options for the toolkit.
75+
CDI cdiOptions
7176

7277
createDeviceNodes cli.StringSlice
7378

@@ -170,21 +175,21 @@ func Flags(opts *Options) []cli.Flag {
170175
Name: "cdi-enabled",
171176
Aliases: []string{"enable-cdi"},
172177
Usage: "enable the generation of a CDI specification",
173-
Destination: &opts.cdiEnabled,
178+
Destination: &opts.CDI.Enabled,
174179
EnvVars: []string{"CDI_ENABLED", "ENABLE_CDI"},
175180
},
176181
&cli.StringFlag{
177182
Name: "cdi-output-dir",
178183
Usage: "the directory where the CDI output files are to be written. If this is set to '', no CDI specification is generated.",
179184
Value: "/var/run/cdi",
180-
Destination: &opts.cdiOutputDir,
185+
Destination: &opts.CDI.outputDir,
181186
EnvVars: []string{"CDI_OUTPUT_DIR"},
182187
},
183188
&cli.StringFlag{
184189
Name: "cdi-kind",
185190
Usage: "the vendor string to use for the generated CDI specification",
186191
Value: "management.nvidia.com/gpu",
187-
Destination: &opts.cdiKind,
192+
Destination: &opts.CDI.kind,
188193
EnvVars: []string{"CDI_KIND"},
189194
},
190195
&cli.BoolFlag{
@@ -240,19 +245,19 @@ func (t *Installer) ValidateOptions(opts *Options) error {
240245
return fmt.Errorf("invalid --toolkit-root option: %v", t.toolkitRoot)
241246
}
242247

243-
vendor, class := parser.ParseQualifier(opts.cdiKind)
248+
vendor, class := parser.ParseQualifier(opts.CDI.kind)
244249
if err := parser.ValidateVendorName(vendor); err != nil {
245250
return fmt.Errorf("invalid CDI vendor name: %v", err)
246251
}
247252
if err := parser.ValidateClassName(class); err != nil {
248253
return fmt.Errorf("invalid CDI class name: %v", err)
249254
}
250-
opts.cdiVendor = vendor
251-
opts.cdiClass = class
255+
opts.CDI.vendor = vendor
256+
opts.CDI.class = class
252257

253-
if opts.cdiEnabled && opts.cdiOutputDir == "" {
258+
if opts.CDI.Enabled && opts.CDI.outputDir == "" {
254259
t.logger.Warning("Skipping CDI spec generation (no output directory specified)")
255-
opts.cdiEnabled = false
260+
opts.CDI.Enabled = false
256261
}
257262

258263
isDisabled := false
@@ -265,7 +270,7 @@ func (t *Installer) ValidateOptions(opts *Options) error {
265270
break
266271
}
267272
}
268-
if !opts.cdiEnabled && !isDisabled {
273+
if !opts.CDI.Enabled && !isDisabled {
269274
t.logger.Info("disabling device node creation since --cdi-enabled=false")
270275
isDisabled = true
271276
}
@@ -698,7 +703,7 @@ func (t *Installer) createDeviceNodes(opts *Options) error {
698703

699704
// generateCDISpec generates a CDI spec for use in management containers
700705
func (t *Installer) generateCDISpec(opts *Options, nvidiaCDIHookPath string) error {
701-
if !opts.cdiEnabled {
706+
if !opts.CDI.Enabled {
702707
return nil
703708
}
704709
t.logger.Info("Generating CDI spec for management containers")
@@ -708,8 +713,8 @@ func (t *Installer) generateCDISpec(opts *Options, nvidiaCDIHookPath string) err
708713
nvcdi.WithDriverRoot(opts.DriverRootCtrPath),
709714
nvcdi.WithDevRoot(opts.DevRootCtrPath),
710715
nvcdi.WithNVIDIACDIHookPath(nvidiaCDIHookPath),
711-
nvcdi.WithVendor(opts.cdiVendor),
712-
nvcdi.WithClass(opts.cdiClass),
716+
nvcdi.WithVendor(opts.CDI.vendor),
717+
nvcdi.WithClass(opts.CDI.class),
713718
)
714719
if err != nil {
715720
return fmt.Errorf("failed to create CDI library for management containers: %v", err)
@@ -734,7 +739,7 @@ func (t *Installer) generateCDISpec(opts *Options, nvidiaCDIHookPath string) err
734739
if err != nil {
735740
return fmt.Errorf("failed to generate CDI name for management containers: %v", err)
736741
}
737-
err = spec.Save(filepath.Join(opts.cdiOutputDir, name))
742+
err = spec.Save(filepath.Join(opts.CDI.outputDir, name))
738743
if err != nil {
739744
return fmt.Errorf("failed to save CDI spec for management containers: %v", err)
740745
}

cmd/nvidia-ctk-installer/container/toolkit/toolkit_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,11 @@ kind: example.com/class
124124
options := Options{
125125
DriverRoot: "/host/driver/root",
126126
DriverRootCtrPath: filepath.Join(moduleRoot, "testdata", "lookup", tc.hostRoot),
127-
cdiEnabled: tc.cdiEnabled,
128-
cdiOutputDir: cdiOutputDir,
129-
cdiKind: "example.com/class",
127+
CDI: cdiOptions{
128+
Enabled: tc.cdiEnabled,
129+
outputDir: cdiOutputDir,
130+
kind: "example.com/class",
131+
},
130132
}
131133

132134
ti := NewInstaller(

cmd/nvidia-ctk-installer/main.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type options struct {
3838
runtimeArgs string
3939
root string
4040
pidFile string
41+
sourceRoot string
4142

4243
toolkitOptions toolkit.Options
4344
runtimeOptions runtime.Options
@@ -141,6 +142,13 @@ func (a app) build() *cli.App {
141142
Destination: &options.root,
142143
EnvVars: []string{"ROOT"},
143144
},
145+
&cli.StringFlag{
146+
Name: "source-root",
147+
Value: "/",
148+
Usage: "The folder where the required toolkit artifacts can be found",
149+
Destination: &options.sourceRoot,
150+
EnvVars: []string{"SOURCE_ROOT"},
151+
},
144152
&cli.StringFlag{
145153
Name: "pid-file",
146154
Value: defaultPidFile,
@@ -159,12 +167,13 @@ func (a app) build() *cli.App {
159167
func (a *app) Before(c *cli.Context, o *options) error {
160168
a.toolkit = toolkit.NewInstaller(
161169
toolkit.WithLogger(a.logger),
170+
toolkit.WithSourceRoot(o.sourceRoot),
162171
toolkit.WithToolkitRoot(o.toolkitRoot()),
163172
)
164173
return a.validateFlags(c, o)
165174
}
166175

167-
func (a *app) validateFlags(_ *cli.Context, o *options) error {
176+
func (a *app) validateFlags(c *cli.Context, o *options) error {
168177
if o.root == "" {
169178
return fmt.Errorf("the install root must be specified")
170179
}
@@ -178,7 +187,7 @@ func (a *app) validateFlags(_ *cli.Context, o *options) error {
178187
if err := a.toolkit.ValidateOptions(&o.toolkitOptions); err != nil {
179188
return err
180189
}
181-
if err := runtime.ValidateOptions(&o.runtimeOptions, o.runtime, o.toolkitRoot()); err != nil {
190+
if err := runtime.ValidateOptions(c, &o.runtimeOptions, o.runtime, o.toolkitRoot(), &o.toolkitOptions); err != nil {
182191
return err
183192
}
184193
return nil

0 commit comments

Comments
 (0)