@@ -20,13 +20,14 @@ import (
2020 "fmt"
2121 "path/filepath"
2222
23+ "github.com/urfave/cli/v2"
24+
2325 "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2426 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine"
2527 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine/containerd"
2628 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine/crio"
2729 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine/docker"
2830 "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/ocihook"
29- "github.com/urfave/cli/v2"
3031)
3132
3233const (
@@ -71,6 +72,11 @@ type config struct {
7172 hookPath string
7273 setAsDefault bool
7374 }
75+
76+ // cdi-specific options
77+ cdi struct {
78+ enabled bool
79+ }
7480}
7581
7682func (m command ) build () * cli.Command {
@@ -141,6 +147,11 @@ func (m command) build() *cli.Command {
141147 Usage : "set the NVIDIA runtime as the default runtime" ,
142148 Destination : & config .nvidiaRuntime .setAsDefault ,
143149 },
150+ & cli.BoolFlag {
151+ Name : "cdi.enabled" ,
152+ Usage : "Enable CDI in the configured runtime" ,
153+ Destination : & config .cdi .enabled ,
154+ },
144155 }
145156
146157 return & configure
@@ -175,6 +186,13 @@ func (m command) validateFlags(c *cli.Context, config *config) error {
175186 }
176187 }
177188
189+ if config .runtime != "containerd" && config .runtime != "docker" {
190+ if config .cdi .enabled {
191+ m .logger .Warningf ("Ignoring cdi.enabled flag for %v" , config .runtime )
192+ }
193+ config .cdi .enabled = false
194+ }
195+
178196 return nil
179197}
180198
@@ -227,6 +245,11 @@ func (m command) configureConfigFile(c *cli.Context, config *config) error {
227245 return fmt .Errorf ("unable to update config: %v" , err )
228246 }
229247
248+ err = enableCDI (config , cfg )
249+ if err != nil {
250+ return fmt .Errorf ("failed to enable CDI in %s: %w" , config .runtime , err )
251+ }
252+
230253 outputPath := config .getOuputConfigPath ()
231254 n , err := cfg .Save (outputPath )
232255 if err != nil {
@@ -277,3 +300,17 @@ func (m *command) configureOCIHook(c *cli.Context, config *config) error {
277300 }
278301 return nil
279302}
303+
304+ // enableCDI enables the use of CDI in the corresponding container engine
305+ func enableCDI (config * config , cfg engine.Interface ) error {
306+ if ! config .cdi .enabled {
307+ return nil
308+ }
309+ switch config .runtime {
310+ case "containerd" :
311+ return cfg .Set ("enable_cdi" , true )
312+ case "docker" :
313+ return cfg .Set ("experimental" , true )
314+ }
315+ return fmt .Errorf ("enabling CDI in %s is not supported" , config .runtime )
316+ }
0 commit comments